feat: dalle系列改为使用模型固定价格计费

This commit is contained in:
CaIon 2024-05-13 16:04:02 +08:00
parent 39f6812a2b
commit 71547849bc
5 changed files with 32 additions and 24 deletions

View File

@ -61,8 +61,6 @@ var DefaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.4, // $0.8 / 1M tokens "claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
@ -117,6 +115,8 @@ var DefaultModelRatio = map[string]float64{
} }
var DefaultModelPrice = map[string]float64{ var DefaultModelPrice = map[string]float64{
"dall-e-2": 0.02,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1, "gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1, "mj_imagine": 0.1,
"mj_variation": 0.1, "mj_variation": 0.1,
@ -160,7 +160,8 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &modelPrice) return json.Unmarshal([]byte(jsonStr), &modelPrice)
} }
func GetModelPrice(name string, printErr bool) float64 { // GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil { if modelPrice == nil {
modelPrice = DefaultModelPrice modelPrice = DefaultModelPrice
} }
@ -172,9 +173,9 @@ func GetModelPrice(name string, printErr bool) float64 {
if printErr { if printErr {
SysError("model price not found: " + name) SysError("model price not found: " + name)
} }
return -1 return -1, false
} }
return price return price, true
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
@ -106,21 +107,27 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
requestBody = c.Request.Body requestBody = c.Request.Body
} }
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
modelRatio := common.GetModelRatio(imageRequest.Model) modelRatio := common.GetModelRatio(imageRequest.Model)
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
}
log.Printf("modelPrice: %f", modelPrice)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0 sizeRatio := 1.0
// Size // Size
if imageRequest.Size == "256x256" { if imageRequest.Size == "256x256" {
sizeRatio = 1 sizeRatio = 0.4
} else if imageRequest.Size == "512x512" { } else if imageRequest.Size == "512x512" {
sizeRatio = 1.125 sizeRatio = 0.45
} else if imageRequest.Size == "1024x1024" { } else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25 sizeRatio = 1
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2.5 sizeRatio = 2
} }
qualityRatio := 1.0 qualityRatio := 1.0
@ -131,7 +138,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
} }
} }
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
if userQuota-quota < 0 { if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@ -190,9 +197,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
if imageRequest.Quality == "hd" { if imageRequest.Quality == "hd" {
quality = "hd" quality = "hd"
} }
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality) logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
other := make(map[string]interface{}) other := make(map[string]interface{})
other["model_ratio"] = modelRatio other["model_price"] = modelPrice
other["group_ratio"] = groupRatio other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)

View File

@ -155,9 +155,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
} }
modelName := service.CoverActionToModelName(constant.MjActionSwapFace) modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice := common.GetModelPrice(modelName, true) modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格 // 如果没有配置价格,则使用默认价格
if modelPrice == -1 { if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName] defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok { if !ok {
modelPrice = 0.1 modelPrice = 0.1
@ -454,9 +454,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action) modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice := common.GetModelPrice(modelName, true) modelPrice, success := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格 // 如果没有配置价格,则使用默认价格
if modelPrice == -1 { if !success {
defaultPrice, ok := common.DefaultModelPrice[modelName] defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok { if !ok {
modelPrice = 0.1 modelPrice = 0.1

View File

@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
} }
} }
relayInfo.UpstreamModelName = textRequest.Model relayInfo.UpstreamModelName = textRequest.Model
modelPrice := common.GetModelPrice(textRequest.Model, false) modelPrice, success := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := common.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int var preConsumedQuota int
@ -108,7 +108,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
} }
if modelPrice == -1 { if !success {
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 { if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens) preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
return nil return nil
} }
@ -257,7 +257,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64) { modelPrice float64, usePrice bool) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
@ -267,7 +267,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
completionRatio := common.GetCompletionRatio(textRequest.Model) completionRatio := common.GetCompletionRatio(textRequest.Model)
quota := 0 quota := 0
if modelPrice == -1 { if !usePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio)) quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {

View File

@ -159,7 +159,7 @@ export function renderModelPrice(
<article> <article>
<p>提示 ${inputRatioPrice} / 1M tokens</p> <p>提示 ${inputRatioPrice} / 1M tokens</p>
<p>补全 ${completionRatioPrice} / 1M tokens</p> <p>补全 ${completionRatioPrice} / 1M tokens</p>
<p>计算过程</p> <p> </p>
<p> <p>
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '} 提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $ {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $