refactor: add GetRatio to Adaptor

This commit is contained in:
WqyJh
2025-01-14 14:38:26 +08:00
parent 3915ce9814
commit 0ad609ade6
70 changed files with 1038 additions and 467 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
@@ -54,9 +55,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := billingratio.GetModelRatio(audioModel, channelType)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio
adaptorRatio := GetRatio(meta, adaptor)
ratio := adaptorRatio.Input * groupRatio
var quota int64
var preConsumedQuota int64
switch relayMode {
@@ -216,7 +224,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, adaptorRatio.Input, groupRatio, audioModel, tokenName, 0, 0)
}(c.Request.Context())
for k, v := range resp.Header {

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/songquanpeng/one-api/relay/constant/role"
"math"
"net/http"
"strings"
@@ -14,9 +13,12 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing/ratio"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/constant/role"
"github.com/songquanpeng/one-api/relay/controller/validator"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -91,17 +93,26 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio billingratio.Ratio, preConsumedQuota int64, groupRatio float64, systemPromptReset bool) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
// use meta.OriginalModelName instead of mapped model name, which may named randomly in azure
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
promptRatio := ratio.Input
completionRatio := ratio.Output
// for gemini, prompt longer than 128k will be charged as long input
if ratio.LongInput > 0 && promptTokens > ratio.LongThreshold {
promptRatio = ratio.LongInput
completionRatio = ratio.LongOutput
}
quota = int64(math.Ceil(groupRatio * (float64(promptTokens)*promptRatio + float64(completionTokens)*completionRatio)))
if quota <= 0 && (ratio.Input > 0 || ratio.Output > 0) {
quota = 1
}
totalTokens := promptTokens + completionTokens
@@ -123,8 +134,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
if systemPromptReset {
extraLog = " (注意系统提示词已被重置)"
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", promptRatio, groupRatio, completionRatio/promptRatio, extraLog)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, meta.OriginModelName, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
@@ -185,3 +196,16 @@ func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIReque
logger.Infof(ctx, "add system prompt")
return true
}
func GetRatio(meta *meta.Meta, adaptor adaptor.Adaptor) ratio.Ratio {
result := billingratio.GetRatio(meta.OriginModelName, meta.ChannelType)
if result != nil {
return *result
}
ratio := adaptor.GetRatio(meta)
if ratio != nil {
return *ratio
}
logger.SysError("model ratio not found: " + meta.OriginModelName)
return billingratio.FallbackRatio
}

View File

@@ -128,7 +128,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
imageModel := imageRequest.Model
// Convert the original image model
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
c.Set("response_format", imageRequest.ResponseFormat)
@@ -167,9 +166,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
adaptorRatio := GetRatio(meta, adaptor)
ratio := adaptorRatio.Input * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
var quota int64
@@ -209,7 +208,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
if quota != 0 {
tokenName := c.GetString(ctxkey.TokenName)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", adaptorRatio.Input, groupRatio)
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt(ctxkey.ChannelId)

View File

@@ -4,10 +4,11 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"io"
"net/http"
"github.com/songquanpeng/one-api/common/config"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
@@ -32,6 +33,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
meta.IsStream = textRequest.Stream
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
// map model name
meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
@@ -39,9 +46,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
adaptorRatio := GetRatio(meta, adaptor)
ratio := adaptorRatio.Input * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
meta.PromptTokens = promptTokens
@@ -51,12 +59,6 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return bizErr
}
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
// get request body
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
@@ -82,7 +84,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return respErr
}
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
go postConsumeQuota(ctx, usage, meta, textRequest, adaptorRatio, preConsumedQuota, groupRatio, systemPromptReset)
return nil
}