From 24b3ed50d75af7b2c243ddb2ea15ae4b93fa8f62 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 7 Oct 2024 19:08:20 +0800 Subject: [PATCH] feat: realtime pre consume (cherry picked from commit d87917f8f6eb9d2e144a9f840d6d91767ea2eb69) --- relay/channel/openai/relay-openai.go | 51 ++++++++++- relay/common/relay_info.go | 1 + relay/websocket.go | 92 +------------------ service/quota.go | 132 +++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 95 deletions(-) create mode 100644 service/quota.go diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 6b11aac..60d09a0 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -389,6 +389,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage := &dto.RealtimeUsage{} localUsage := &dto.RealtimeUsage{} + sumUsage := &dto.RealtimeUsage{} go func() { for { @@ -478,6 +479,12 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + err := preConsumeUsage(c, info, usage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + usage = &dto.RealtimeUsage{} } else { textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) if err != nil { @@ -490,7 +497,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.AudioTokens += audioToken + err = preConsumeUsage(c, info, localUsage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + localUsage = &dto.RealtimeUsage{} + // print now usage } + common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session if realtimeSession != nil { @@ -528,15 +546,38 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op select { case <-clientClosed: case <-targetClosed: - case <-errChan: + case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil + common.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } + if usage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, usage, sumUsage) + } + + if localUsage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, localUsage, sumUsage) + } + // check usage total tokens, if 0, use local usage - if usage.TotalTokens == 0 { - usage = localUsage - } - return nil, usage + return nil, sumUsage +} + +func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error { + totalUsage.TotalTokens += usage.TotalTokens + totalUsage.InputTokens += usage.InputTokens + totalUsage.OutputTokens += usage.OutputTokens + totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens + totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens + totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens + totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens + totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens + // clear usage + err := service.PreWssConsumeQuota(ctx, info, usage) + if err == nil { + common.LogInfo(ctx, "realtime streaming consume usage success") + } + return err } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b43f917..21e3691 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -23,6 +23,7 @@ type RelayInfo struct { ApiType int IsStream bool IsPlayground bool + UsePrice bool RelayMode int UpstreamModelName string OriginModelName string diff --git a/relay/websocket.go b/relay/websocket.go index 089805d..09d8298 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -5,15 +5,11 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "math" "net/http" "one-api/common" "one-api/dto" - "one-api/model" relaycommon "one-api/relay/common" "one-api/service" - "strings" - "time" ) //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { @@ -91,6 +87,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + relayInfo.UsePrice = true } // pre-consume quota 预消耗配额 @@ -126,95 +123,10 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } -func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, - groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { - - useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() - textInputTokens := usage.InputTokenDetails.TextTokens - textOutTokens := usage.OutputTokenDetails.TextTokens - - audioInputTokens := usage.InputTokenDetails.AudioTokens - audioOutTokens := usage.OutputTokenDetails.AudioTokens - - tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(modelName) - - quota := 0 - if !usePrice { - quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } - totalTokens := usage.TotalTokens - var logContent string - if !usePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } - - // record all the consume log even if quota is 0 - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) - } else { - //if sensitiveResp != nil { - // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - //} - quotaDelta := quota - preConsumedQuota - if quotaDelta != 0 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - } - err := model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) - model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) - } - - logModel := modelName - if strings.HasPrefix(logModel, "gpt-4-gizmo") { - logModel = "gpt-4-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) - } - if strings.HasPrefix(logModel, "gpt-4o-gizmo") { - logModel = "gpt-4o-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) - } - if extraContent != "" { - logContent += ", " + extraContent - } - other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) - - //if quota != 0 { - // - //} -} - //func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { // var promptTokens int // var err error diff --git a/service/quota.go b/service/quota.go new file mode 100644 index 0000000..09c2fd5 --- /dev/null +++ b/service/quota.go @@ -0,0 +1,132 @@ +package service + +import ( + "fmt" + "github.com/gin-gonic/gin" + "math" + "one-api/common" + "one-api/dto" + "one-api/model" + relaycommon "one-api/relay/common" + "strings" + "time" +) + +func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { + if relayInfo.UsePrice { + return nil + } + modelName := relayInfo.UpstreamModelName + textInputTokens := usage.InputTokenDetails.TextTokens + textOutTokens := usage.OutputTokenDetails.TextTokens + audioInputTokens := usage.InputTokenDetails.AudioTokens + audioOutTokens := usage.OutputTokenDetails.AudioTokens + + completionRatio := common.GetCompletionRatio(modelName) + audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + groupRatio := common.GetGroupRatio(relayInfo.Group) + modelRatio := common.GetModelRatio(modelName) + + ratio := groupRatio * modelRatio + + quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + + err := model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) + if err != nil { + return err + } + err = model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + return err + } + return nil +} + +func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, + usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, + groupRatio float64, + modelPrice float64, usePrice bool, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + textInputTokens := usage.InputTokenDetails.TextTokens + textOutTokens := usage.OutputTokenDetails.TextTokens + + audioInputTokens := usage.InputTokenDetails.AudioTokens + audioOutTokens := usage.OutputTokenDetails.AudioTokens + + tokenName := ctx.GetString("token_name") + completionRatio := common.GetCompletionRatio(modelName) + audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + + quota := 0 + if !usePrice { + quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + totalTokens := usage.TotalTokens + var logContent string + if !usePrice { + logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) + } else { + logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) + } + + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += fmt.Sprintf("(可能是上游超时)") + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + } else { + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} + //quotaDelta := quota - preConsumedQuota + //if quotaDelta != 0 { + // err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + // if err != nil { + // common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + // } + //} + + //err := model.CacheUpdateUserQuota(relayInfo.UserId) + //if err != nil { + // common.LogError(ctx, "error update user quota cache: "+err.Error()) + //} + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + logModel := modelName + if strings.HasPrefix(logModel, "gpt-4-gizmo") { + logModel = "gpt-4-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if strings.HasPrefix(logModel, "gpt-4o-gizmo") { + logModel = "gpt-4o-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if extraContent != "" { + logContent += ", " + extraContent + } + other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) +}