From 1f103202a6a233bac5fbc438067a558cc3602e4e Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 22 Nov 2024 03:12:09 +0000 Subject: [PATCH] fix: refactor postConsumeQuota function to return quota and update user request cost handling --- relay/controller/helper.go | 6 +++--- relay/controller/text.go | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 74a9858e..d859532e 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -91,7 +91,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) (quota int64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return @@ -100,7 +100,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota := int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } @@ -128,7 +128,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) - return + return quota } func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { diff --git a/relay/controller/text.go b/relay/controller/text.go index 84d3dff8..a51c9241 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -90,7 +91,22 @@ func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { } // post-consume quota - go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) + go func() { + quota := postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) + + // also update user request cost + if quota != 0 { + docu := model.NewUserRequestCost( + c.GetInt(ctxkey.Id), + c.GetString(ctxkey.RequestId), + quota, + ) + if err = docu.Insert(); err != nil { + logger.Errorf(c, "insert user request cost failed: %+v", err) + } + } + }() + return nil }