diff --git a/model/option.go b/model/option.go index c749b231..030f0b54 100644 --- a/model/option.go +++ b/model/option.go @@ -83,15 +83,35 @@ func InitOptionMap() { func loadOptionsFromDatabase() { options, _ := AllOption() + var oldModelRatio string + var oldCompletionRatio string for _, option := range options { if option.Key == "ModelRatio" { + oldModelRatio = option.Value option.Value = billingratio.AddNewMissingRatio(option.Value) } + if option.Key == "CompletionRatio" { + oldCompletionRatio = option.Value + } err := updateOptionMap(option.Key, option.Value) if err != nil { logger.SysError("failed to update option map: " + err.Error()) } } + for _, option := range options { + if option.Key == "Ratio" { + option.Value = billingratio.AddOldRatio(oldModelRatio, oldCompletionRatio) + err := updateOptionMap(option.Key, option.Value) + if err != nil { + logger.SysError("failed to update option map: " + err.Error()) + } + err = UpdateOption(option.Key, option.Value) + if err != nil { + logger.SysError("failed to update option map: " + err.Error()) + } + logger.SysLog("ratio merged") + } + } } func SyncOptions(frequency int) { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 80dad948..da795490 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -3,6 +3,7 @@ package ratio import ( "encoding/json" "fmt" + "strconv" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -394,6 +395,58 @@ func AddNewMissingRatio(oldRatio string) string { return string(jsonBytes) } +func AddOldRatio(oldRatio string, oldCompletionRatio string) string { + modelRatio := make(map[string]float64) + if oldRatio != "" { + err := json.Unmarshal([]byte(oldRatio), &modelRatio) + if err != nil { + logger.SysError("error unmarshalling old ratio: " + err.Error()) + return oldRatio + } + } + + completionRatio := make(map[string]float64) + if oldCompletionRatio != "" { + err := json.Unmarshal([]byte(oldCompletionRatio), &completionRatio) + if err != nil { + logger.SysError("error unmarshalling old completion ratio: " + err.Error()) + return oldCompletionRatio + } + } + + newRatio := make(map[string]Ratio) + + for k, v := range DefaultRatio { + if _, ok := newRatio[k]; !ok { + newRatio[k] = v + } + } + + for k, v := range modelRatio { + if _, ok := DefaultRatio[k]; ok { + continue + } + modelName, channelType := SplitModelName(k) + ratio := Ratio{} + ratio.Input = v + + if val, ok := completionRatio[k]; ok { + ratio.Output = v * val + } else { + ratio.Output = v * GetCompletionRatio(modelName, channelType) + } + + newRatio[k] = ratio + } + + jsonBytes, err := json.Marshal(newRatio) + if err != nil { + logger.SysError("error marshalling new ratio: " + err.Error()) + return oldRatio + } + return string(jsonBytes) +} + func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { @@ -444,6 +497,18 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &CompletionRatio) } +func SplitModelName(name string) (string, int) { + model := strings.Split(name, "(") + modelName := model[0] + channelType := 0 + if len(model) > 1 { + if v, err := strconv.Atoi(model[1]); err == nil { + channelType = v + } + } + return modelName, channelType +} + func GetCompletionRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet")