diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a5..75c6da51 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( AvailableModels = "available_models" KeyRequestBody = "key_request_body" SystemPrompt = "system_prompt" + Meta = "meta" ) diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go index 670c7628..06fddee0 100644 --- a/relay/adaptor/proxy/adaptor.go +++ b/relay/adaptor/proxy/adaptor.go @@ -60,7 +60,6 @@ func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId) return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil - } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 14a23a51..d52de788 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -71,7 +71,7 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, "tts-1": 7.5, // $0.015 / 1K characters "tts-1-1106": 7.5, "tts-1-hd": 15, // $0.030 / 1K characters @@ -380,8 +380,9 @@ func GetAudioCompletionRatio(actualModelName string) float64 { // AudioTokensPerSecond is the number of audio tokens per second for each model. 
var AudioPromptTokensPerSecond = map[string]float64{ - // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens - "whisper-1": 1000 / 20, + // The whisper API is priced at $0.0001/sec. one-api's historical ratio of 15 corresponds to $0.03 per 1K tokens, + // so the equivalent token rate is 0.0001/0.03*1000 = 3.3333 tokens per second. + "whisper-1": 0.0001 / 0.03 * 1000, // gpt-4o-audio series processes 10 tokens per second "gpt-4o-audio-preview": 10, "gpt-4o-audio-preview-2024-12-17": 10, diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 03d79b3d..d8937224 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -129,17 +129,6 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M model.UpdateChannelUsedQuota(meta.ChannelId, quota) } -func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { - if mapping == nil { - return modelName, false - } - mappedModelName := mapping[modelName] - if mappedModelName != "" { - return mappedModelName, true - } - return modelName, false -} - func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { if resp == nil { if meta.ChannelType == channeltype.AwsClaude { diff --git a/relay/controller/image.go b/relay/controller/image.go index 1b69d97d..581859f1 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -18,7 +18,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" + metalib "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" ) @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) 
*relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) @@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -114,7 +114,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // map model name var isModelMapped bool meta.OriginModelName = imageRequest.Model - imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) + imageRequest.Model = meta.ActualModelName + isModelMapped = meta.OriginModelName != meta.ActualModelName meta.ActualModelName = imageRequest.Model // model validation @@ -130,7 +131,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus imageModel := imageRequest.Model // Convert the original image model - imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) + imageRequest.Model = metalib.GetMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader diff --git a/relay/controller/text.go b/relay/controller/text.go index 203719f6..69a51386 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -17,13 +17,13 @@ import ( "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/model" + metalib 
"github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { +func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { @@ -34,7 +34,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name meta.OriginModelName = textRequest.Model - textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model = meta.ActualModelName meta.ActualModelName = textRequest.Model // set system prompt if not empty systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) @@ -86,9 +86,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return nil } -func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { - if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { - // no need to convert request for openai +func getRequestBody(c *gin.Context, meta *metalib.Meta, textRequest *relaymodel.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { + if !config.EnforceIncludeUsage && + meta.APIType == apitype.OpenAI && + meta.OriginModelName == meta.ActualModelName && + meta.ChannelType != channeltype.OpenAI && // openai also need to convert request + meta.ChannelType != channeltype.Baichuan { return c.Request.Body, nil } diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index bcbe1045..02b19504 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,12 +1,13 @@ package meta import ( + "strings" + "github.com/gin-gonic/gin" 
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" - "strings" ) type Meta struct { @@ -33,6 +34,20 @@ type Meta struct { SystemPrompt string } +// GetMappedModelName returns mapping[modelName] when that entry is non-empty; otherwise it returns modelName unchanged. +func GetMappedModelName(modelName string, mapping map[string]string) string { + if mapping == nil { + return modelName + } + + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName + } + + return modelName +} + func GetByContext(c *gin.Context) *Meta { meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), @@ -44,6 +59,7 @@ func GetByContext(c *gin.Context) *Meta { Group: c.GetString(ctxkey.Group), ModelMapping: c.GetStringMapString(ctxkey.ModelMapping), OriginModelName: c.GetString(ctxkey.RequestModel), + ActualModelName: c.GetString(ctxkey.RequestModel), BaseURL: c.GetString(ctxkey.BaseURL), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), @@ -57,5 +73,13 @@ func GetByContext(c *gin.Context) *Meta { meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } meta.APIType = channeltype.ToAPIType(meta.ChannelType) + + meta.ActualModelName = GetMappedModelName(meta.OriginModelName, meta.ModelMapping) + + Set2Context(c, &meta) return &meta } + +func Set2Context(c *gin.Context, meta *Meta) { + c.Set(ctxkey.Meta, meta) +}