fix: panic when get model ratio (close #392)

This commit is contained in:
CalciumIon 2024-07-27 18:09:09 +08:00
parent b7bc205b73
commit b16e6bf423

View File

@ -3,6 +3,7 @@ package common
import ( import (
"encoding/json" "encoding/json"
"strings" "strings"
"sync"
) )
// from songquanpeng/one-api // from songquanpeng/one-api
@ -182,8 +183,14 @@ var defaultModelPrice = map[string]float64{
"swap_face": 0.05, "swap_face": 0.05,
} }
var modelPrice map[string]float64 = nil var (
var modelRatio map[string]float64 = nil modelPriceMap = make(map[string]float64)
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
@ -191,11 +198,18 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4-all": 2, "gpt-4-all": 2,
} }
func ModelPrice2JSONString() string { func GetModelPriceMap() map[string]float64 {
if modelPrice == nil { modelPriceMapMutex.Lock()
modelPrice = defaultModelPrice defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
} }
jsonBytes, err := json.Marshal(modelPrice) return modelPriceMap
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil { if err != nil {
SysError("error marshalling model price: " + err.Error()) SysError("error marshalling model price: " + err.Error())
} }
@ -203,19 +217,19 @@ func ModelPrice2JSONString() string {
} }
func UpdateModelPriceByJSONString(jsonStr string) error { func UpdateModelPriceByJSONString(jsonStr string) error {
modelPrice = make(map[string]float64) modelPriceMapMutex.Lock()
return json.Unmarshal([]byte(jsonStr), &modelPrice) defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
} }
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false // GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) { func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil { GetModelPriceMap()
modelPrice = defaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
price, ok := modelPrice[name] price, ok := modelPriceMap[name]
if !ok { if !ok {
if printErr { if printErr {
SysError("model price not found: " + name) SysError("model price not found: " + name)
@ -225,18 +239,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true return price, true
} }
func GetModelPriceMap() map[string]float64 { func GetModelRatioMap() map[string]float64 {
if modelPrice == nil { modelRatioMapMutex.Lock()
modelPrice = defaultModelPrice defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
} }
return modelPrice return modelRatioMap
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
if modelRatio == nil { GetModelRatioMap()
modelRatio = defaultModelRatio jsonBytes, err := json.Marshal(modelRatioMap)
}
jsonBytes, err := json.Marshal(modelRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) SysError("error marshalling model ratio: " + err.Error())
} }
@ -244,18 +258,18 @@ func ModelRatio2JSONString() string {
} }
func UpdateModelRatioByJSONString(jsonStr string) error { func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatio = make(map[string]float64) modelRatioMapMutex.Lock()
return json.Unmarshal([]byte(jsonStr), &modelRatio) defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
if modelRatio == nil { GetModelRatioMap()
modelRatio = defaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
ratio, ok := modelRatio[name] ratio, ok := modelRatioMap[name]
if !ok { if !ok {
SysError("model ratio not found: " + name) SysError("model ratio not found: " + name)
return 30 return 30