diff --git a/common/helper/audio.go b/common/helper/audio.go
index 9db62f42..e5689904 100644
--- a/common/helper/audio.go
+++ b/common/helper/audio.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"io"
+	"math"
 	"os"
 	"os/exec"
 	"strconv"
@@ -13,7 +14,11 @@ import (
 
 // SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
 func SaveTmpFile(filename string, data io.Reader) (string, error) {
-	f, err := os.CreateTemp(os.TempDir(), filename)
+	if data == nil {
+		return "", errors.New("data is nil")
+	}
+
+	f, err := os.CreateTemp("", "*-"+filename)
 	if err != nil {
 		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
 	}
@@ -27,6 +32,22 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
 	return f.Name(), nil
 }
 
+// GetAudioTokens returns the number of tokens in an audio file.
+func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) {
+	filename, err := SaveTmpFile("audio", audio)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to save audio to temporary file")
+	}
+	defer os.Remove(filename)
+
+	duration, err := GetAudioDuration(ctx, filename)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio duration")
+	}
+
+	return int(math.Ceil(duration)) * tokensPerSecond, nil
+}
+
 // GetAudioDuration returns the duration of an audio file in seconds.
 func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
 	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
@@ -36,5 +57,7 @@ func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
 		return 0, errors.Wrap(err, "failed to get audio duration")
 	}
 
+	// Actually gpt-4o-audio calculates tokens with 0.1s precision,
+	// while whisper calculates tokens with 1s precision
 	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
 }
diff --git a/common/helper/audio_test.go b/common/helper/audio_test.go
index 90f334a3..15f55bbb 100644
--- a/common/helper/audio_test.go
+++ b/common/helper/audio_test.go
@@ -35,3 +35,21 @@ func TestGetAudioDuration(t *testing.T) {
 		require.Error(t, err)
 	})
 }
+
+func TestGetAudioTokens(t *testing.T) {
+	t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
+		// download test audio file
+		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
+		require.NoError(t, err)
+		defer resp.Body.Close()
+
+		tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
+		require.NoError(t, err)
+		require.Equal(t, 200, tokens)
+	})
+
+	t.Run("should return an error for nil input", func(t *testing.T) {
+		_, err := GetAudioTokens(context.Background(), nil, 1)
+		require.Error(t, err)
+	})
+}
diff --git a/middleware/recover.go b/middleware/recover.go
index cfc3f827..a690c77b 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -14,11 +14,11 @@ func RelayPanicRecover() gin.HandlerFunc {
 		defer func() {
 			if err := recover(); err != nil {
 				ctx := c.Request.Context()
-				logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
-				logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
-				logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
+				logger.Errorf(ctx, "panic detected: %v", err)
+				logger.Errorf(ctx, "stacktrace from panic: %s", string(debug.Stack()))
+				logger.Errorf(ctx, "request: %s %s", c.Request.Method, c.Request.URL.Path)
 				body, _ := common.GetRequestBody(c)
-				logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
+				logger.Errorf(ctx, "request body: %s", string(body))
 				c.JSON(http.StatusInternalServerError, gin.H{
 					"error": gin.H{
 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go
index 6946e402..21966262 100644
--- a/relay/adaptor/openai/adaptor.go
+++ b/relay/adaptor/openai/adaptor.go
@@ -82,6 +82,27 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		}
 		request.StreamOptions.IncludeUsage = true
 	}
+
+	// o1/o1-mini/o1-preview do not support system prompt and max_tokens
+	if strings.HasPrefix(request.Model, "o1") {
+		request.MaxTokens = 0
+		request.Messages = func(raw []model.Message) (filtered []model.Message) {
+			for i := range raw {
+				if raw[i].Role != "system" {
+					filtered = append(filtered, raw[i])
+				}
+			}
+
+			return
+		}(request.Messages)
+	}
+
+	if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") {
+		// TODO: Since it is not clear how to implement billing in stream mode,
+		// it is temporarily not supported
+		return nil, errors.New("stream mode is not supported for gpt-4o-audio")
+	}
+
 	return request, nil
 }
diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go
index 8a643bc6..2c34284f 100644
--- a/relay/adaptor/openai/constants.go
+++ b/relay/adaptor/openai/constants.go
@@ -12,6 +12,7 @@ var ModelList = []string{
 	"gpt-4o-2024-11-20",
 	"chatgpt-4o-latest",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
+	"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
 	"gpt-4-vision-preview",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 97080738..095a6adb 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -5,15 +5,16 @@ import (
 	"bytes"
 	"encoding/json"
 	"io"
+	"math"
 	"net/http"
 	"strings"
 
-	"github.com/songquanpeng/one-api/common/render"
-
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/render"
+	"github.com/songquanpeng/one-api/relay/billing/ratio"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/relaymode"
 )
@@ -96,6 +97,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 	return nil, responseText, usage
 }
 
+// Handler handles the non-stream response from OpenAI API
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	var textResponse SlimTextResponse
 	responseBody, err := io.ReadAll(resp.Body)
@@ -146,6 +148,22 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 			CompletionTokens: completionTokens,
 			TotalTokens:      promptTokens + completionTokens,
 		}
+	} else {
+		// Convert the more expensive audio tokens to uniformly priced text tokens
+		textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(textResponse.PromptTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName),
+			))
+		textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(textResponse.CompletionTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName)*
+					ratio.GetAudioCompletionRatio(modelName),
+			))
+		textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens +
+			textResponse.Usage.CompletionTokens
 	}
+
 	return nil, &textResponse.Usage
 }
diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go
index 7c8468b9..1287e44b 100644
--- a/relay/adaptor/openai/token.go
+++ b/relay/adaptor/openai/token.go
@@ -1,16 +1,22 @@
 package openai
 
 import (
-	"errors"
+	"bytes"
+	"context"
+	"encoding/base64"
 	"fmt"
-	"github.com/pkoukk/tiktoken-go"
-	"github.com/songquanpeng/one-api/common/config"
-	"github.com/songquanpeng/one-api/common/image"
-	"github.com/songquanpeng/one-api/common/logger"
-	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
-	"github.com/songquanpeng/one-api/relay/model"
 	"math"
 	"strings"
+
+	"github.com/pkg/errors"
+	"github.com/pkoukk/tiktoken-go"
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
+	"github.com/songquanpeng/one-api/common/image"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/billing/ratio"
+	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
+	"github.com/songquanpeng/one-api/relay/model"
 )
 
 // tokenEncoderMap won't grow after initialization
@@ -70,8 +76,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func CountTokenMessages(messages []model.Message, model string) int {
-	tokenEncoder := getTokenEncoder(model)
+func CountTokenMessages(ctx context.Context,
+	messages []model.Message, actualModel string) int {
+	tokenEncoder := getTokenEncoder(actualModel)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -79,7 +86,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 	var tokensPerMessage int
 	var tokensPerName int
-	if model == "gpt-3.5-turbo-0301" {
+	if actualModel == "gpt-3.5-turbo-0301" {
 		tokensPerMessage = 4
 		tokensPerName = -1 // If there's a name, the role is omitted
 	} else {
@@ -89,37 +96,38 @@ func CountTokenMessages(messages []model.Message, model string) int {
 	tokenNum := 0
 	for _, message := range messages {
 		tokenNum += tokensPerMessage
-		switch v := message.Content.(type) {
-		case string:
-			tokenNum += getTokenNum(tokenEncoder, v)
-		case []any:
-			for _, it := range v {
-				m := it.(map[string]any)
-				switch m["type"] {
-				case "text":
-					if textValue, ok := m["text"]; ok {
-						if textString, ok := textValue.(string); ok {
-							tokenNum += getTokenNum(tokenEncoder, textString)
-						}
-					}
-				case "image_url":
-					imageUrl, ok := m["image_url"].(map[string]any)
-					if ok {
-						url := imageUrl["url"].(string)
-						detail := ""
-						if imageUrl["detail"] != nil {
-							detail = imageUrl["detail"].(string)
-						}
-						imageTokens, err := countImageTokens(url, detail, model)
-						if err != nil {
-							logger.SysError("error counting image tokens: " + err.Error())
-						} else {
-							tokenNum += imageTokens
-						}
-					}
+		contents := message.ParseContent()
+		for _, content := range contents {
+			switch content.Type {
+			case model.ContentTypeText:
+				tokenNum += getTokenNum(tokenEncoder, content.Text)
+			case model.ContentTypeImageURL:
+				imageTokens, err := countImageTokens(
+					content.ImageURL.Url,
+					content.ImageURL.Detail,
+					actualModel)
+				if err != nil {
+					logger.SysError("error counting image tokens: " + err.Error())
+				} else {
+					tokenNum += imageTokens
+				}
+			case model.ContentTypeInputAudio:
+				audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
+				if err != nil {
+					logger.SysError("error decoding audio data: " + err.Error())
+				}
+
+				tokens, err := helper.GetAudioTokens(ctx,
+					bytes.NewReader(audioData),
+					ratio.GetAudioPromptTokensPerSecond(actualModel))
+				if err != nil {
+					logger.SysError("error counting audio tokens: " + err.Error())
+				} else {
+					tokenNum += tokens
 				}
 			}
 		}
+
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
 		if message.Name != nil {
 			tokenNum += tokensPerName
diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go
index d1720a99..14a23a51 100644
--- a/relay/billing/ratio/model.go
+++ b/relay/billing/ratio/model.go
@@ -3,6 +3,7 @@ package ratio
 import (
 	"encoding/json"
 	"fmt"
+	"math"
 	"strings"
 
 	"github.com/songquanpeng/one-api/common/logger"
@@ -22,65 +23,71 @@ const (
 // 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
 	// https://openai.com/pricing
-	"gpt-4": 15,
-	"gpt-4-0314": 15,
-	"gpt-4-0613": 15,
-	"gpt-4-32k": 30,
-	"gpt-4-32k-0314": 30,
-	"gpt-4-32k-0613": 30,
-	"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
-	"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
-	"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
-	"gpt-4-turbo": 5, // $0.01 / 1K tokens
-	"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
-	"gpt-4o": 2.5, // $0.005 / 1K tokens
-	"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
-	"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
-	"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
-	"gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens
-	"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
-	"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
-	"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
-	"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
-	"gpt-3.5-turbo-0301": 0.75,
-	"gpt-3.5-turbo-0613": 0.75,
-	"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
-	"gpt-3.5-turbo-16k-0613": 1.5,
-	"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
-	"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
-	"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
-	"o1": 7.5, // $15.00 / 1M input tokens
-	"o1-2024-12-17": 7.5,
-	"o1-preview": 7.5, // $15.00 / 1M input tokens
-	"o1-preview-2024-09-12": 7.5,
-	"o1-mini": 1.5, // $3.00 / 1M input tokens
-	"o1-mini-2024-09-12": 1.5,
-	"davinci-002": 1, // $0.002 / 1K tokens
-	"babbage-002": 0.2, // $0.0004 / 1K tokens
-	"text-ada-001": 0.2,
-	"text-babbage-001": 0.25,
-	"text-curie-001": 1,
-	"text-davinci-002": 10,
-	"text-davinci-003": 10,
-	"text-davinci-edit-001": 10,
-	"code-davinci-edit-001": 10,
-	"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
-	"tts-1": 7.5, // $0.015 / 1K characters
-	"tts-1-1106": 7.5,
-	"tts-1-hd": 15, // $0.030 / 1K characters
-	"tts-1-hd-1106": 15,
-	"davinci": 10,
-	"curie": 10,
-	"babbage": 10,
-	"ada": 10,
-	"text-embedding-ada-002": 0.05,
-	"text-embedding-3-small": 0.01,
-	"text-embedding-3-large": 0.065,
-	"text-search-ada-doc-001": 10,
-	"text-moderation-stable": 0.1,
-	"text-moderation-latest": 0.1,
-	"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
-	"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
+	"gpt-4": 15,
+	"gpt-4-0314": 15,
+	"gpt-4-0613": 15,
+	"gpt-4-32k": 30,
+	"gpt-4-32k-0314": 30,
+	"gpt-4-32k-0613": 30,
+	"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4o": 2.5, // $0.005 / 1K tokens + "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens + "gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens + "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens + "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + // Audio billing will mix text and audio tokens, the unit price is different. + // Here records the cost of text, the cost multiplier of audio + // relative to text is in AudioRatio + "gpt-4o-audio-preview": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-12-17": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0025 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens + "o1": 7.5, // $15.00 / 1M input tokens + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, // $15.00 / 1M input tokens + "o1-preview-2024-09-12": 7.5, + "o1-mini": 1.5, // $3.00 / 1M input tokens + "o1-mini-2024-09-12": 1.5, + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, @@ -254,7 +261,6 @@ var ModelRatio = map[string]float64{ "llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, "llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD, "mixtral-8x7b-32768": 0.24 / 1000000 * USD, - // https://platform.lingyiwanwu.com/docs#-计费单元 "yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB, @@ -333,6 +339,68 @@ var ModelRatio = map[string]float64{ "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, } +// AudioRatio represents the price ratio between audio tokens and text tokens +var AudioRatio = map[string]float64{ + "gpt-4o-audio-preview": 16, + "gpt-4o-audio-preview-2024-12-17": 16, + "gpt-4o-audio-preview-2024-10-01": 40, +} + +// GetAudioPromptRatio returns the audio prompt ratio for the given model. +func GetAudioPromptRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioRatio[actualModelName]; ok { + v = ratio + } else { + v = 16 + } + + return v +} + +// AudioCompletionRatio is the completion ratio for audio models. 
+var AudioCompletionRatio = map[string]float64{ + "whisper-1": 0, + "gpt-4o-audio-preview": 2, + "gpt-4o-audio-preview-2024-12-17": 2, + "gpt-4o-audio-preview-2024-10-01": 2, +} + +// GetAudioCompletionRatio returns the completion ratio for audio models. +func GetAudioCompletionRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioCompletionRatio[actualModelName]; ok { + v = ratio + } else { + v = 2 + } + + return v +} + +// AudioTokensPerSecond is the number of audio tokens per second for each model. +var AudioPromptTokensPerSecond = map[string]float64{ + // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens + "whisper-1": 1000 / 20, + // gpt-4o-audio series processes 10 tokens per second + "gpt-4o-audio-preview": 10, + "gpt-4o-audio-preview-2024-12-17": 10, + "gpt-4o-audio-preview-2024-10-01": 10, +} + +// GetAudioPromptTokensPerSecond returns the number of audio tokens per second +// for the given model. +func GetAudioPromptTokensPerSecond(actualModelName string) int { + var v float64 + if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { + v = tokensPerSecond + } else { + v = 10 + } + + return int(math.Ceil(v)) +} + var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, @@ -397,19 +465,21 @@ func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := ModelRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[model]; ok { - return ratio - } - if ratio, ok := ModelRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + ModelRatio, + DefaultModelRatio, + AudioRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + logger.SysError("model ratio not found: " + name) return 30 } @@ -432,18 +502,19 @@ func GetCompletionRatio(name string, channelType int) float64 { name = strings.TrimSuffix(name, "-internet") } model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := CompletionRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[model]; ok { - return ratio - } - if ratio, ok := CompletionRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + CompletionRatio, + DefaultCompletionRatio, + AudioCompletionRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + if strings.HasPrefix(name, "gpt-3.5") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates diff --git a/relay/controller/audio.go b/relay/controller/audio.go index bc756f65..b90666d3 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,10 +7,8 @@ import ( "encoding/json" "fmt" "io" - "math" "mime/multipart" "net/http" - "os" "strings" "github.com/gin-gonic/gin" @@ -23,6 +21,7 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" + "github.com/songquanpeng/one-api/relay/billing/ratio" billingratio 
"github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" @@ -30,10 +29,6 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) -const ( - TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens -) - type commonAudioRequest struct { File *multipart.FileHeader `form:"file" binding:"required"` } @@ -54,27 +49,13 @@ func countAudioTokens(c *gin.Context) (int, error) { if err != nil { return 0, errors.WithStack(err) } + defer reqFp.Close() - tmpFp, err := os.CreateTemp("", "audio-*") - if err != nil { - return 0, errors.WithStack(err) - } - defer os.Remove(tmpFp.Name()) + ctxMeta := meta.GetByContext(c) - _, err = io.Copy(tmpFp, reqFp) - if err != nil { - return 0, errors.WithStack(err) - } - if err = tmpFp.Close(); err != nil { - return 0, errors.WithStack(err) - } - - duration, err := helper.GetAudioDuration(c.Request.Context(), tmpFp.Name()) - if err != nil { - return 0, errors.WithStack(err) - } - - return int(math.Ceil(duration)) * TokensPerSecond, nil + return helper.GetAudioTokens(c.Request.Context(), + reqFp, + ratio.GetAudioPromptTokensPerSecond(ctxMeta.ActualModelName)) } func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 5f5fc90c..03d79b3d 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/songquanpeng/one-api/relay/constant/role" "math" "net/http" "strings" @@ -17,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/constant/role" "github.com/songquanpeng/one-api/relay/controller/validator" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -42,10 +42,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { +func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: - return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model) case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) case relaymode.Moderations: diff --git a/relay/controller/text.go b/relay/controller/text.go index 9a47c58b..203719f6 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,11 +4,11 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/config" "io" "net/http" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor" @@ -43,7 +43,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota - promptTokens := getPromptTokens(textRequest, meta.Mode) + promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode) meta.PromptTokens = 
promptTokens preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) if bizErr != nil { diff --git a/relay/model/general.go b/relay/model/general.go index 288c07ff..5354694c 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -23,36 +23,37 @@ type StreamOptions struct { type GeneralOpenAIRequest struct { // https://platform.openai.com/docs/api-reference/chat/create - Messages []Message `json:"messages,omitempty"` - Model string `json:"model,omitempty"` - Store *bool `json:"store,omitempty"` - Metadata any `json:"metadata,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - LogitBias any `json:"logit_bias,omitempty"` - Logprobs *bool `json:"logprobs,omitempty"` - TopLogprobs *int `json:"top_logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` - N int `json:"n,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Prediction any `json:"prediction,omitempty"` - Audio *Audio `json:"audio,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - ServiceTier *string `json:"service_tier,omitempty"` - Stop any `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` - User string `json:"user,omitempty"` - FunctionCall any `json:"function_call,omitempty"` - Functions any `json:"functions,omitempty"` + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store *bool `json:"store,omitempty"` + Metadata any `json:"metadata,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias any `json:"logit_bias,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + // Modalities currently the model only programmatically allows modalities = [“text”, “audio”] + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` // https://platform.openai.com/docs/api-reference/embeddings/create Input any `json:"input,omitempty"` EncodingFormat string 
`json:"encoding_format,omitempty"` diff --git a/relay/model/message.go b/relay/model/message.go index b908f989..48ddb3ad 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,11 +1,26 @@ package model +import ( + "context" + + "github.com/songquanpeng/one-api/common/logger" +) + type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` - ToolCallId string `json:"tool_call_id,omitempty"` + Role string `json:"role,omitempty"` + // Content is a string or a list of objects + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` + Audio *messageAudio `json:"audio,omitempty"` +} + +type messageAudio struct { + Id string `json:"id"` + Data string `json:"data,omitempty"` + ExpiredAt int `json:"expired_at,omitempty"` + Transcript string `json:"transcript,omitempty"` } func (m Message) IsStringContent() bool { @@ -26,6 +41,7 @@ func (m Message) StringContent() string { if !ok { continue } + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr @@ -34,6 +50,7 @@ func (m Message) StringContent() string { } return contentStr } + return "" } @@ -47,6 +64,7 @@ func (m Message) ParseContent() []MessageContent { }) return contentList } + anyList, ok := m.Content.([]any) if ok { for _, contentItem := range anyList { @@ -71,8 +89,21 @@ func (m Message) ParseContent() []MessageContent { }, }) } + case ContentTypeInputAudio: + if subObj, ok := contentMap["input_audio"].(map[string]any); ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeInputAudio, + InputAudio: &InputAudio{ + Data: subObj["data"].(string), + Format: subObj["format"].(string), + }, + }) + } + default: + logger.Warnf(context.TODO(), "unknown content type: %s", contentMap["type"]) } } + return contentList } return nil @@ -84,7 +115,18 @@ type ImageURL struct { } type MessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` + // Type should be one of the following: text/input_audio + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` + InputAudio *InputAudio `json:"input_audio,omitempty"` +} + +type InputAudio struct { + // Data is the base64 encoded audio data + Data string `json:"data" binding:"required"` + // Format is the audio format, should be one of the + // following: mp3/mp4/mpeg/mpga/m4a/wav/webm/pcm16. 
+ // When stream=true, format should be pcm16 + Format string `json:"format"` } diff --git a/relay/model/misc.go b/relay/model/misc.go index 163bc398..ff3f061d 100644 --- a/relay/model/misc.go +++ b/relay/model/misc.go @@ -1,9 +1,13 @@ package model type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"` + ServiceTier string `gorm:"-" json:"service_tier,omitempty"` + SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"` } type Error struct { @@ -17,3 +21,18 @@ type ErrorWithStatusCode struct { Error StatusCode int `json:"status_code"` } + +type usagePromptTokensDetails struct { + CachedTokens int `json:"cached_tokens"` + AudioTokens int `json:"audio_tokens"` + TextTokens int `json:"text_tokens"` + ImageTokens int `json:"image_tokens"` +} + +type usageCompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` + TextTokens int `json:"text_tokens"` +}
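
Below is a minimal usage sketch (not part of the patch) of how the audio-billing helpers introduced above compose; the model name and audio file path are hypothetical placeholders.

	package main

	import (
		"context"
		"fmt"
		"math"
		"os"

		"github.com/songquanpeng/one-api/common/helper"
		"github.com/songquanpeng/one-api/relay/billing/ratio"
	)

	func main() {
		modelName := "gpt-4o-audio-preview" // hypothetical model name
		f, err := os.Open("sample.m4a")     // hypothetical audio file
		if err != nil {
			panic(err)
		}
		defer f.Close()

		// Audio prompt tokens: ceil(duration in seconds) * tokens-per-second for the model.
		audioTokens, err := helper.GetAudioTokens(context.Background(), f,
			ratio.GetAudioPromptTokensPerSecond(modelName))
		if err != nil {
			panic(err)
		}

		// Convert audio tokens into text-equivalent tokens for uniform billing,
		// mirroring the conversion done in the openai Handler above.
		promptTextEquivalent := int(math.Ceil(float64(audioTokens) * ratio.GetAudioPromptRatio(modelName)))
		completionMultiplier := ratio.GetAudioPromptRatio(modelName) * ratio.GetAudioCompletionRatio(modelName)

		fmt.Println("audio prompt tokens:", audioTokens)
		fmt.Println("text-equivalent prompt tokens:", promptTextEquivalent)
		fmt.Println("completion audio-token multiplier:", completionMultiplier)
	}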