From 71547849bc4f2eb374294bfb79c529a39c01b55f Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 13 May 2024 16:04:02 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20dalle=E7=B3=BB=E5=88=97=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E4=BD=BF=E7=94=A8=E6=A8=A1=E5=9E=8B=E5=9B=BA=E5=AE=9A?= =?UTF-8?q?=E4=BB=B7=E6=A0=BC=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 11 ++++++----- relay/relay-image.go | 25 ++++++++++++++++--------- relay/relay-mj.go | 8 ++++---- relay/relay-text.go | 10 +++++----- web/src/helpers/render.js | 2 +- 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index f470317..a8db3b3 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -61,8 +61,6 @@ var DefaultModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, - "dall-e-3": 16, "claude-instant-1": 0.4, // $0.8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens @@ -117,6 +115,8 @@ var DefaultModelRatio = map[string]float64{ } var DefaultModelPrice = map[string]float64{ + "dall-e-2": 0.02, + "dall-e-3": 0.04, "gpt-4-gizmo-*": 0.1, "mj_imagine": 0.1, "mj_variation": 0.1, @@ -160,7 +160,8 @@ func UpdateModelPriceByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &modelPrice) } -func GetModelPrice(name string, printErr bool) float64 { +// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false +func GetModelPrice(name string, printErr bool) (float64, bool) { if modelPrice == nil { modelPrice = DefaultModelPrice } @@ -172,9 +173,9 @@ func GetModelPrice(name string, printErr bool) float64 { if printErr { SysError("model price not found: " + name) } - return -1 + return -1, false } - return price + return price, true } func ModelRatio2JSONString() string { diff --git a/relay/relay-image.go b/relay/relay-image.go index 7f8cd9e..346d72d 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "io" + "log" "net/http" "one-api/common" "one-api/dto" @@ -106,21 +107,27 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageRequest.Model) + modelPrice, success := common.GetModelPrice(imageRequest.Model, true) + if !success { + modelRatio := common.GetModelRatio(imageRequest.Model) + // modelRatio 16 = modelPrice $0.04 + // per 1 modelRatio = $0.04 / 16 + modelPrice = 0.0025 * modelRatio + } + log.Printf("modelPrice: %f", modelPrice) groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) sizeRatio := 1.0 // Size if imageRequest.Size == "256x256" { - sizeRatio = 1 + sizeRatio = 0.4 } else if imageRequest.Size == "512x512" { - sizeRatio = 1.125 + sizeRatio = 0.45 } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1.25 + sizeRatio = 1 } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2.5 + sizeRatio = 2 } qualityRatio := 1.0 @@ -131,7 +138,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } } - quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N + quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N if userQuota-quota < 0 { return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -190,9 +197,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC if imageRequest.Quality == "hd" { quality = "hd" } - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality) + logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality) other := make(map[string]interface{}) - other["model_ratio"] = modelRatio + other["model_price"] = modelPrice other["group_ratio"] = groupRatio model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 16ad412..b28f026 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -155,9 +155,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - modelPrice := common.GetModelPrice(modelName, true) + modelPrice, success := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 - if modelPrice == -1 { + if !success { defaultPrice, ok := common.DefaultModelPrice[modelName] if !ok { modelPrice = 0.1 @@ -454,9 +454,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) modelName := service.CoverActionToModelName(midjRequest.Action) - modelPrice := common.GetModelPrice(modelName, true) + modelPrice, success := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 - if modelPrice == -1 { + if !success { defaultPrice, ok := common.DefaultModelPrice[modelName] if !ok { modelPrice = 0.1 diff --git a/relay/relay-text.go b/relay/relay-text.go index bf3cba0..d5ee728 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } } relayInfo.UpstreamModelName = textRequest.Model - modelPrice := common.GetModelPrice(textRequest.Model, false) + modelPrice, success := common.GetModelPrice(textRequest.Model, false) groupRatio := common.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -108,7 +108,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } - if modelPrice == -1 { + if !success { preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + int(textRequest.MaxTokens) @@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) return nil } @@ -257,7 +257,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64) { + modelPrice float64, usePrice bool) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -267,7 +267,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe completionRatio := common.GetCompletionRatio(textRequest.Model) quota := 0 - if modelPrice == -1 { + if !usePrice { quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = int(math.Round(float64(quota) * ratio)) if ratio != 0 && quota <= 0 { diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index f0cbd81..3113fed 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -159,7 +159,7 @@ export function renderModelPrice(

提示 ${inputRatioPrice} / 1M tokens

补全 ${completionRatioPrice} / 1M tokens

-

计算过程:

+

提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '} {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $