feat: batch update with laisky's one-api

Laisky.Cai
2025-03-20 10:43:14 +00:00
parent 761ee32d19
commit b2d6aa783b
35 changed files with 479 additions and 108 deletions

View File

@@ -39,16 +39,24 @@ func (a *Adaptor) Init(meta *meta.Meta) {
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.ChannelType {
case channeltype.Azure:
defaultVersion := meta.Config.APIVersion
// https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning?tabs=python#api--feature-support
if strings.HasPrefix(meta.ActualModelName, "o1") ||
strings.HasPrefix(meta.ActualModelName, "o3") {
defaultVersion = "2024-12-01-preview"
}
if meta.Mode == relaymode.ImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, defaultVersion)
return fullRequestURL, nil
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, defaultVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := meta.ActualModelName
model_ = strings.Replace(model_, ".", "", -1)
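Note: the override above pins o1/o3 models to a newer api-version than the channel's configured default. Below is a standalone sketch of that resolution logic with a hypothetical resource name; the real adaptor reads these values from meta rather than taking them as arguments.

package main

import (
	"fmt"
	"strings"
)

// resolveAPIVersion mirrors the hunk above: o1/o3 reasoning models force a
// newer preview api-version; everything else keeps the configured one.
func resolveAPIVersion(configured, actualModelName string) string {
	if strings.HasPrefix(actualModelName, "o1") ||
		strings.HasPrefix(actualModelName, "o3") {
		return "2024-12-01-preview"
	}
	return configured
}

func main() {
	base := "https://example.openai.azure.com" // hypothetical resource
	version := resolveAPIVersion("2024-03-01-preview", "o1-mini")
	fmt.Printf("%s/openai/deployments/o1-mini/chat/completions?api-version=%s\n", base, version)
}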
@@ -160,7 +168,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -191,6 +199,8 @@ func (a *Adaptor) DoResponse(c *gin.Context,
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.ImagesEdits:
err, _ = ImagesEditsHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}

View File

@@ -7,11 +7,10 @@ var ModelList = []string{
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@@ -3,12 +3,30 @@ package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
)
// ImagesEditsHandler simply copies the upstream response body to the client.
//
// https://platform.openai.com/docs/api-reference/images/createEdit
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
defer resp.Body.Close()
// Headers must be set before WriteHeader; anything set afterwards is silently dropped.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}
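A hedged test sketch for the passthrough handler, assuming a hypothetical test file in the same openai package; it uses only the stdlib httptest recorder and gin's test context, fakes an upstream response, and checks it is forwarded verbatim.

package openai

import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)

func TestImagesEditsHandlerPassthrough(t *testing.T) {
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": {"application/json"}},
		Body:       io.NopCloser(strings.NewReader(`{"created":0,"data":[]}`)),
	}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	if errResp, _ := ImagesEditsHandler(c, upstream); errResp != nil {
		t.Fatalf("unexpected error: %+v", errResp)
	}
	if rec.Code != http.StatusOK || rec.Body.String() != `{"created":0,"data":[]}` {
		t.Fatalf("response not forwarded verbatim: %d %q", rec.Code, rec.Body.String())
	}
}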
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"encoding/json"
"io"
"math"
"net/http"
"strings"
@@ -13,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
@@ -23,144 +25,300 @@ const (
dataPrefixLength = len(dataPrefix)
)
// StreamHandler processes streaming responses from OpenAI API
// It handles incremental content delivery and accumulates the final response text
// Returns error (if any), accumulated response text, and token usage information
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
// Initialize accumulators for the response
responseText := ""
reasoningText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
var usage *model.Usage
// Set up scanner for reading the stream line by line
scanner := bufio.NewScanner(resp.Body)
buffer := make([]byte, 256*1024) // 256KB buffer for large messages
scanner.Buffer(buffer, len(buffer))
scanner.Split(bufio.ScanLines)
// Set response headers for SSE
common.SetEventStreamHeaders(c)
doneRendered := false
// Process each line from the stream
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
continue
data := NormalizeDataLine(scanner.Text())
// logger.Debugf(c.Request.Context(), "stream response: %s", data)
// Skip lines that don't match expected format
if len(data) < dataPrefixLength {
continue // Ignore blank line or wrong format
}
// Verify line starts with expected prefix
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
// Check for stream termination
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
doneRendered = true
continue
}
// Process based on relay mode
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
// Parse the JSON response
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
logger.Errorf(c.Request.Context(), "unmarshalling stream data %q got %+v", data, err)
render.StringData(c, data) // Pass raw data to client if parsing fails
continue
}
// Skip empty choices (Azure specific behavior)
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
// but for empty choice and no usage, we should not pass it to client, this is for azure
continue // just ignore empty choice
continue
}
render.StringData(c, data)
// Process each choice in the response
for _, choice := range streamResponse.Choices {
if choice.Delta.Reasoning != nil {
reasoningText += *choice.Delta.Reasoning
}
if choice.Delta.ReasoningContent != nil {
reasoningText += *choice.Delta.ReasoningContent
// Extract reasoning content from different possible fields
currentReasoningChunk := extractReasoningContent(&choice.Delta)
// Update accumulated reasoning text
if currentReasoningChunk != "" {
reasoningText += currentReasoningChunk
}
// Set the reasoning content in the format requested by client
choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
// Accumulate response content
responseText += conv.AsString(choice.Delta.Content)
}
// Send the processed data to the client
render.StringData(c, data)
// Update usage information if available
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
// Send the data immediately for Completions mode
render.StringData(c, data)
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
// Accumulate text from all choices
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
// Check for scanner errors
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
// Ensure stream termination is sent to client
if !doneRendered {
render.Done(c)
}
err := resp.Body.Close()
if err != nil {
// Clean up resources
if err := resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
}
// Return the complete response text (reasoning + content) and usage
return nil, reasoningText + responseText, usage
}
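The enlarged scanner buffer is the load-bearing change here: bufio.Scanner defaults to a 64KB token limit, and a single oversized SSE line would end the scan with bufio.ErrTooLong. A minimal, self-contained illustration of the pattern:

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Simulated SSE payload; in the handler the source is resp.Body.
	stream := "data: {\"choices\":[]}\n\ndata: [DONE]\n"
	scanner := bufio.NewScanner(strings.NewReader(stream))
	// Raise the token limit from the 64KB default to 256KB so one large
	// delta does not abort the whole stream.
	buf := make([]byte, 256*1024)
	scanner.Buffer(buf, len(buf))
	scanner.Split(bufio.ScanLines)
	for scanner.Scan() {
		fmt.Printf("line: %q\n", scanner.Text())
	}
}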
// extractReasoningContent extracts reasoning text from a message delta and
// clears the source fields so they are not serialized a second time
func extractReasoningContent(delta *model.Message) string {
content := ""
// Extract reasoning from different possible fields
if delta.Reasoning != nil {
content += *delta.Reasoning
delta.Reasoning = nil
}
if delta.ReasoningContent != nil {
content += *delta.ReasoningContent
delta.ReasoningContent = nil
}
return content
}
// Handler processes non-streaming responses from OpenAI API
// Returns error (if any) and token usage information
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
// Read the entire response body
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
// Close the original response body
if err = resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
// Parse the response JSON
var textResponse SlimTextResponse
if err = json.Unmarshal(responseBody, &textResponse); err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Check for API errors
if textResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// Don't set response headers before parsing the body: parsing may fail, and
// we would then need to send an error response, but the headers would already
// have been written. That confuses HTTP clients; Postman, for example,
// reports an error and the response cannot be inspected at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
// Process reasoning content in each choice
for _, msg := range textResponse.Choices {
reasoningContent := processReasoningContent(&msg)
// Set reasoning in requested format if content exists
if reasoningContent != "" {
msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
}
}
// Reset response body for forwarding to client
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))
// Forward all response headers (not just first value of each)
for k, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(k, v)
}
}
// Set response status and copy body to client
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
if _, err = io.Copy(c.Writer, resp.Body); err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
// Close the reconstructed response body
if err = resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 ||
(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
// Calculate token usage if not provided by API
calculateTokenUsage(&textResponse, promptTokens, modelName)
return nil, &textResponse.Usage
}
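Switching the header loop from Set to Add matters for multi-valued headers such as Set-Cookie, where Set keeps only the last value. A small stdlib-only demonstration:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	before := http.Header{}
	before.Set("Set-Cookie", "a=1")
	before.Set("Set-Cookie", "b=2")   // Set overwrites: only "b=2" survives
	fmt.Println(before["Set-Cookie"]) // [b=2]

	after := http.Header{}
	after.Add("Set-Cookie", "a=1")
	after.Add("Set-Cookie", "b=2")   // Add appends: both values reach the client
	fmt.Println(after["Set-Cookie"]) // [a=1 b=2]
}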
// processReasoningContent is a helper function to extract and process reasoning content from the message
func processReasoningContent(msg *TextResponseChoice) string {
var reasoningContent string
// Check different locations for reasoning content
switch {
case msg.Reasoning != nil:
reasoningContent = *msg.Reasoning
msg.Reasoning = nil
case msg.ReasoningContent != nil:
reasoningContent = *msg.ReasoningContent
msg.ReasoningContent = nil
case msg.Message.Reasoning != nil:
reasoningContent = *msg.Message.Reasoning
msg.Message.Reasoning = nil
case msg.Message.ReasoningContent != nil:
reasoningContent = *msg.Message.ReasoningContent
msg.Message.ReasoningContent = nil
}
return reasoningContent
}
// calculateTokenUsage fills in token usage when the API response omits it,
// and converts audio tokens into billable equivalents when present
func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
// Calculate tokens if not provided by the API
if response.Usage.TotalTokens == 0 ||
(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range textResponse.Choices {
for _, choice := range response.Choices {
// Count content tokens
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
// Count reasoning tokens in all possible locations
if choice.Message.Reasoning != nil {
completionTokens += CountToken(*choice.Message.Reasoning)
}
if choice.Message.ReasoningContent != nil {
completionTokens += CountToken(*choice.Message.ReasoningContent)
}
if choice.Reasoning != nil {
completionTokens += CountToken(*choice.Reasoning)
}
if choice.ReasoningContent != nil {
completionTokens += CountToken(*choice.ReasoningContent)
}
}
textResponse.Usage = model.Usage{
// Set usage values
response.Usage = model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
} else if hasAudioTokens(response) {
// Handle audio tokens conversion
calculateAudioTokens(response, modelName)
}
}
// hasAudioTokens reports whether the response contains audio tokens
func hasAudioTokens(response *SlimTextResponse) bool {
return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
}
// calculateAudioTokens converts audio tokens into their text-token billing equivalents
func calculateAudioTokens(response *SlimTextResponse, modelName string) {
// Convert audio tokens for prompt
if response.PromptTokensDetails != nil {
response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
int(math.Ceil(
float64(response.PromptTokensDetails.AudioTokens)*
ratio.GetAudioPromptRatio(modelName),
))
}
return nil, &textResponse.Usage
// Convert audio tokens for completion
if response.CompletionTokensDetails != nil {
response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
int(math.Ceil(
float64(response.CompletionTokensDetails.AudioTokens)*
ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
))
}
// Calculate total tokens
response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
}
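A hedged numeric illustration of the audio conversion, using made-up token counts and ratios; the real values come from ratio.GetAudioPromptRatio and ratio.GetAudioCompletionRatio.

package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical usage: 40 text + 100 audio prompt tokens, and an audio
	// prompt ratio of 8.0 (one audio token bills like eight text tokens).
	promptText, promptAudio, promptRatio := 40, 100, 8.0
	promptTokens := promptText + int(math.Ceil(float64(promptAudio)*promptRatio))
	fmt.Println(promptTokens) // 840

	// Completion audio additionally multiplies by the completion ratio.
	completionText, completionAudio, completionRatio := 20, 30, 2.0
	completionTokens := completionText + int(math.Ceil(float64(completionAudio)*promptRatio*completionRatio))
	fmt.Println(completionTokens)                // 500
	fmt.Println(promptTokens + completionTokens) // total: 1340
}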

View File

@@ -1,6 +1,10 @@
package openai
import "github.com/songquanpeng/one-api/relay/model"
import (
"mime/multipart"
"github.com/songquanpeng/one-api/relay/model"
)
type TextContent struct {
Type string `json:"type,omitempty"`
@@ -71,6 +75,24 @@ type TextToSpeechRequest struct {
ResponseFormat string `json:"response_format"`
}
type AudioTranscriptionRequest struct {
File *multipart.FileHeader `form:"file" binding:"required"`
Model string `form:"model" binding:"required"`
Language string `form:"language"`
Prompt string `form:"prompt"`
ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
Temperature float64 `form:"temperature"`
TimestampGranularity []string `form:"timestamp_granularity"`
}
type AudioTranslationRequest struct {
File *multipart.FileHeader `form:"file" binding:"required"`
Model string `form:"model" binding:"required"`
Prompt string `form:"prompt"`
ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
Temperature float64 `form:"temperature"`
}
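For context, a minimal sketch of how gin binds such a multipart request; the route and the trimmed two-field struct are hypothetical stand-ins for the real relay wiring.

package main

import (
	"mime/multipart"
	"net/http"

	"github.com/gin-gonic/gin"
)

// transcriptionForm mirrors the binding tags above, trimmed to two fields.
type transcriptionForm struct {
	File  *multipart.FileHeader `form:"file" binding:"required"`
	Model string                `form:"model" binding:"required"`
}

func main() {
	r := gin.New()
	// gin populates the struct from the multipart fields and rejects
	// requests that are missing the required file or model.
	r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
		var req transcriptionForm
		if err := c.ShouldBind(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusOK, gin.H{"model": req.Model, "file": req.File.Filename})
	})
	_ = r.Run(":8080")
}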
type UsageOrResponseText struct {
*model.Usage
ResponseText string
@@ -110,12 +132,14 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
// ImageData represents an image in the response
type ImageData struct {
Url string `json:"url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageResponse represents the response structure for image generations
type ImageResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`

View File

@@ -1,16 +1,20 @@
package openai
import (
"errors"
"bytes"
"context"
"encoding/base64"
"fmt"
"math"
"strings"
"github.com/pkg/errors"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/billing/ratio"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// CountTokenMessages counts the number of tokens in a list of messages.
func CountTokenMessages(ctx context.Context,
messages []model.Message, actualModel string) int {
tokenEncoder := getTokenEncoder(actualModel)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
if actualModel == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
var totalAudioTokens float64
for _, message := range messages {
tokenNum += tokensPerMessage
switch v := message.Content.(type) {
case string:
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
switch m["type"] {
case "text":
if textValue, ok := m["text"]; ok {
if textString, ok := textValue.(string); ok {
tokenNum += getTokenNum(tokenEncoder, textString)
}
}
case "image_url":
imageUrl, ok := m["image_url"].(map[string]any)
if ok {
url := imageUrl["url"].(string)
detail := ""
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail, model)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
}
contents := message.ParseContent()
for _, content := range contents {
switch content.Type {
case model.ContentTypeText:
if content.Text != nil {
tokenNum += getTokenNum(tokenEncoder, *content.Text)
}
case model.ContentTypeImageURL:
imageTokens, err := countImageTokens(
content.ImageURL.Url,
content.ImageURL.Detail,
actualModel)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
case model.ContentTypeInputAudio:
audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
if err != nil {
logger.SysError("error decoding audio data: " + err.Error())
continue // don't tokenize audio that failed to decode
}
audioTokens, err := helper.GetAudioTokens(ctx,
bytes.NewReader(audioData),
ratio.GetAudioPromptTokensPerSecond(actualModel))
if err != nil {
logger.SysError("error counting audio tokens: " + err.Error())
} else {
totalAudioTokens += audioTokens
}
}
}
tokenNum += int(math.Ceil(totalAudioTokens))
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName

View File

@@ -3,6 +3,7 @@ package openai
import (
"context"
"fmt"
"strings"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
@@ -21,3 +22,11 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
StatusCode: statusCode,
}
}
// NormalizeDataLine normalizes an SSE line so that the "data:" field name is
// always followed by exactly one space before the payload.
func NormalizeDataLine(data string) string {
if strings.HasPrefix(data, "data:") {
content := strings.TrimLeft(data[len("data:"):], " ")
return "data: " + content
}
return data
}
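NormalizeDataLine exists because some upstreams emit "data:" with no space (or several) before the payload, which breaks prefix-based parsing downstream. A standalone demo of the behavior, duplicating the function for runnability:

package main

import (
	"fmt"
	"strings"
)

// normalize duplicates NormalizeDataLine for a self-contained demo.
func normalize(data string) string {
	if strings.HasPrefix(data, "data:") {
		return "data: " + strings.TrimLeft(data[len("data:"):], " ")
	}
	return data
}

func main() {
	fmt.Println(normalize(`data:{"id":1}`))  // data: {"id":1}
	fmt.Println(normalize("data:   [DONE]")) // data: [DONE]
	fmt.Println(normalize(": keep-alive"))   // non-data lines pass through
}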