mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-11 19:03:43 +08:00
refactor: add GetRatio to Adaptor
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
@@ -54,9 +55,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(audioModel, channelType)
|
||||
adaptor := relay.GetAdaptor(meta.APIType)
|
||||
if adaptor == nil {
|
||||
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
groupRatio := billingratio.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
adaptorRatio := GetRatio(meta, adaptor)
|
||||
ratio := adaptorRatio.Input * groupRatio
|
||||
|
||||
var quota int64
|
||||
var preConsumedQuota int64
|
||||
switch relayMode {
|
||||
@@ -216,7 +224,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
succeed = true
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, adaptorRatio.Input, groupRatio, audioModel, tokenName, 0, 0)
|
||||
}(c.Request.Context())
|
||||
|
||||
for k, v := range resp.Header {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/relay/constant/role"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -14,9 +13,12 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/constant/role"
|
||||
"github.com/songquanpeng/one-api/relay/controller/validator"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
@@ -91,17 +93,26 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
|
||||
return preConsumedQuota, nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
|
||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio billingratio.Ratio, preConsumedQuota int64, groupRatio float64, systemPromptReset bool) {
|
||||
if usage == nil {
|
||||
logger.Error(ctx, "usage is nil, which is unexpected")
|
||||
return
|
||||
}
|
||||
var quota int64
|
||||
completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
|
||||
// use meta.OriginalModelName instead of mapped model name, which may named randomly in azure
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
promptRatio := ratio.Input
|
||||
completionRatio := ratio.Output
|
||||
|
||||
// for gemini, prompt longer than 128k will be charged as long input
|
||||
if ratio.LongInput > 0 && promptTokens > ratio.LongThreshold {
|
||||
promptRatio = ratio.LongInput
|
||||
completionRatio = ratio.LongOutput
|
||||
}
|
||||
quota = int64(math.Ceil(groupRatio * (float64(promptTokens)*promptRatio + float64(completionTokens)*completionRatio)))
|
||||
|
||||
if quota <= 0 && (ratio.Input > 0 || ratio.Output > 0) {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -123,8 +134,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
|
||||
if systemPromptReset {
|
||||
extraLog = " (注意系统提示词已被重置)"
|
||||
}
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog)
|
||||
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", promptRatio, groupRatio, completionRatio/promptRatio, extraLog)
|
||||
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, meta.OriginModelName, meta.TokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
|
||||
}
|
||||
@@ -185,3 +196,16 @@ func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIReque
|
||||
logger.Infof(ctx, "add system prompt")
|
||||
return true
|
||||
}
|
||||
|
||||
func GetRatio(meta *meta.Meta, adaptor adaptor.Adaptor) ratio.Ratio {
|
||||
result := billingratio.GetRatio(meta.OriginModelName, meta.ChannelType)
|
||||
if result != nil {
|
||||
return *result
|
||||
}
|
||||
ratio := adaptor.GetRatio(meta)
|
||||
if ratio != nil {
|
||||
return *ratio
|
||||
}
|
||||
logger.SysError("model ratio not found: " + meta.OriginModelName)
|
||||
return billingratio.FallbackRatio
|
||||
}
|
||||
|
||||
@@ -128,7 +128,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
imageModel := imageRequest.Model
|
||||
// Convert the original image model
|
||||
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
|
||||
c.Set("response_format", imageRequest.ResponseFormat)
|
||||
@@ -167,9 +166,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
adaptorRatio := GetRatio(meta, adaptor)
|
||||
ratio := adaptorRatio.Input * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
var quota int64
|
||||
@@ -209,7 +208,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString(ctxkey.TokenName)
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", adaptorRatio.Input, groupRatio)
|
||||
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
@@ -32,6 +33,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
}
|
||||
meta.IsStream = textRequest.Stream
|
||||
|
||||
adaptor := relay.GetAdaptor(meta.APIType)
|
||||
if adaptor == nil {
|
||||
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
// map model name
|
||||
meta.OriginModelName = textRequest.Model
|
||||
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
@@ -39,9 +46,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
// set system prompt if not empty
|
||||
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
|
||||
// get model ratio & group ratio
|
||||
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
adaptorRatio := GetRatio(meta, adaptor)
|
||||
ratio := adaptorRatio.Input * groupRatio
|
||||
|
||||
// pre-consume quota
|
||||
promptTokens := getPromptTokens(textRequest, meta.Mode)
|
||||
meta.PromptTokens = promptTokens
|
||||
@@ -51,12 +59,6 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
return bizErr
|
||||
}
|
||||
|
||||
adaptor := relay.GetAdaptor(meta.APIType)
|
||||
if adaptor == nil {
|
||||
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
// get request body
|
||||
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
|
||||
if err != nil {
|
||||
@@ -82,7 +84,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
return respErr
|
||||
}
|
||||
// post-consume quota
|
||||
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
|
||||
go postConsumeQuota(ctx, usage, meta, textRequest, adaptorRatio, preConsumedQuota, groupRatio, systemPromptReset)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user