diff --git a/common/model-ratio.go b/common/model-ratio.go
index 46505fe..3016f06 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -77,6 +77,35 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
+var ModelPrice = map[string]float64{
+ "gpt-4-gizmo-*": 0.1,
+}
+
+func ModelPrice2JSONString() string {
+ jsonBytes, err := json.Marshal(ModelPrice)
+ if err != nil {
+ SysError("error marshalling model price: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateModelPriceByJSONString(jsonStr string) error {
+ ModelPrice = make(map[string]float64)
+ return json.Unmarshal([]byte(jsonStr), &ModelPrice)
+}
+
+func GetModelPrice(name string) float64 {
+ if strings.HasPrefix(name, "gpt-4-gizmo") {
+ name = "gpt-4-gizmo-*"
+ }
+ price, ok := ModelPrice[name]
+ if !ok {
+ //SysError("model price not found: " + name)
+ return -1
+ }
+ return price
+}
+
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 82104c4..9a12a6c 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -231,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
- }
- modelRatio := common.GetModelRatio(textRequest.Model)
+ modelPrice := common.GetModelPrice(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+
+ var preConsumedQuota int
+ var ratio float64
+ var modelRatio float64
+ if modelPrice == -1 {
+ preConsumedTokens := common.PreConsumedQuota
+ if textRequest.MaxTokens != 0 {
+ preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
+ }
+ modelRatio = common.GetModelRatio(textRequest.Model)
+ ratio = modelRatio * groupRatio
+ preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+ } else {
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ }
+
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -447,15 +457,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
- quota := 0
- completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
- quota = promptTokens + int(float64(completionTokens)*completionRatio)
- quota = int(float64(quota) * ratio)
- if ratio != 0 && quota <= 0 {
- quota = 1
+ quota := 0
+ if modelPrice == -1 {
+ completionRatio := common.GetCompletionRatio(textRequest.Model)
+ quota = promptTokens + int(float64(completionTokens)*completionRatio)
+ quota = int(float64(quota) * ratio)
+ if ratio != 0 && quota <= 0 {
+ quota = 1
+ }
+ } else {
+ quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
@@ -474,7 +488,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
// record all the consume log even if quota is 0
useTimeSeconds := time.Now().Unix() - startTime.Unix()
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
+ var logContent string
+ if modelPrice == -1 {
+ logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
+ } else {
+ logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
+ }
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
diff --git a/model/ability.go b/model/ability.go
index 4121608..f060991 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -6,7 +6,7 @@ import (
)
type Ability struct {
- Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
+ Group string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
diff --git a/model/channel.go b/model/channel.go
index d56c24e..1f7dd2d 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -21,7 +21,7 @@ type Channel struct {
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
- Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
+ Group string `json:"group" gorm:"type:varchar(255);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
diff --git a/model/option.go b/model/option.go
index f68a08e..d98904c 100644
--- a/model/option.go
+++ b/model/option.go
@@ -70,6 +70,7 @@ func InitOptionMap() {
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+ common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
+ case "ModelPrice":
+ err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChatLink":
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index c71de3c..f2713b4 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -10,6 +10,7 @@ const OperationSetting = () => {
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
+ ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
@@ -30,7 +31,7 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
- if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+ if (item.key === 'ModelRatio' || item.key === 'GroupRatio'|| item.key === 'ModelPrice') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
@@ -97,6 +98,13 @@ const OperationSetting = () => {
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
+ if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
+ if (!verifyJSON(inputs.ModelPrice)) {
+ showError('模型固定价格不是合法的 JSON 字符串');
+ return;
+ }
+ await updateOption('ModelPrice', inputs.ModelPrice);
+ }
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
@@ -315,6 +323,17 @@ const OperationSetting = () => {