diff --git a/common/model-ratio.go b/common/model-ratio.go index 46505fe..3016f06 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -77,6 +77,35 @@ var ModelRatio = map[string]float64{ "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } +var ModelPrice = map[string]float64{ + "gpt-4-gizmo-*": 0.1, +} + +func ModelPrice2JSONString() string { + jsonBytes, err := json.Marshal(ModelPrice) + if err != nil { + SysError("error marshalling model price: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelPriceByJSONString(jsonStr string) error { + ModelPrice = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &ModelPrice) +} + +func GetModelPrice(name string) float64 { + if strings.HasPrefix(name, "gpt-4-gizmo") { + name = "gpt-4-gizmo-*" + } + price, ok := ModelPrice[name] + if !ok { + //SysError("model price not found: " + name) + return -1 + } + return price +} + func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { diff --git a/controller/relay-text.go b/controller/relay-text.go index 82104c4..9a12a6c 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -231,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case RelayModeModerations: promptTokens = countTokenInput(textRequest.Input, textRequest.Model) } - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + int(textRequest.MaxTokens) - } - modelRatio := common.GetModelRatio(textRequest.Model) + modelPrice := common.GetModelPrice(textRequest.Model) groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + if modelPrice == -1 { + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + int(textRequest.MaxTokens) + } + modelRatio = common.GetModelRatio(textRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + userQuota, err := model.CacheGetUserQuota(userId) if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -447,15 +457,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens completionTokens = textResponse.Usage.CompletionTokens - quota = promptTokens + int(float64(completionTokens)*completionRatio) - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 + quota := 0 + if modelPrice == -1 { + completionRatio := common.GetCompletionRatio(textRequest.Model) + quota = promptTokens + int(float64(completionTokens)*completionRatio) + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit * groupRatio) } totalTokens := promptTokens + completionTokens if totalTokens == 0 { @@ -474,7 +488,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } // record all the consume log even if quota is 0 useTimeSeconds := time.Now().Unix() - startTime.Unix() - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds) + var logContent string + if modelPrice == -1 { + logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds) + } else { + logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds) + } model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) diff --git a/model/ability.go b/model/ability.go index 4121608..f060991 100644 --- a/model/ability.go +++ b/model/ability.go @@ -6,7 +6,7 @@ import ( ) type Ability struct { - Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` + Group string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` diff --git a/model/channel.go b/model/channel.go index d56c24e..1f7dd2d 100644 --- a/model/channel.go +++ b/model/channel.go @@ -21,7 +21,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"type:varchar(255);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` diff --git a/model/option.go b/model/option.go index f68a08e..d98904c 100644 --- a/model/option.go +++ b/model/option.go @@ -70,6 +70,7 @@ func InitOptionMap() { common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink @@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) { err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) + case "ModelPrice": + err = common.UpdateModelPriceByJSONString(value) case "TopUpLink": common.TopUpLink = value case "ChatLink": diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index c71de3c..f2713b4 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -10,6 +10,7 @@ const OperationSetting = () => { QuotaRemindThreshold: 0, PreConsumedQuota: 0, ModelRatio: '', + ModelPrice: '', GroupRatio: '', TopUpLink: '', ChatLink: '', @@ -30,7 +31,7 @@ const OperationSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio'|| item.key === 'ModelPrice') { item.value = JSON.stringify(JSON.parse(item.value), null, 2); } newInputs[item.key] = item.value; @@ -97,6 +98,13 @@ const OperationSetting = () => { } await updateOption('GroupRatio', inputs.GroupRatio); } + if (originInputs['ModelPrice'] !== inputs.ModelPrice) { + if (!verifyJSON(inputs.ModelPrice)) { + showError('模型固定价格不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelPrice', inputs.ModelPrice); + } break; case 'quota': if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { @@ -315,6 +323,17 @@ const OperationSetting = () => {
倍率设置
+ + +