diff --git a/common/model-ratio.go b/common/model-ratio.go index 5de961d..568254b 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -3,6 +3,7 @@ package common import ( "encoding/json" "strings" + "sync" ) // from songquanpeng/one-api @@ -182,8 +183,14 @@ var defaultModelPrice = map[string]float64{ "swap_face": 0.05, } -var modelPrice map[string]float64 = nil -var modelRatio map[string]float64 = nil +var ( + modelPriceMap = make(map[string]float64) + modelPriceMapMutex = sync.RWMutex{} +) +var ( + modelRatioMap map[string]float64 = nil + modelRatioMapMutex = sync.RWMutex{} +) var CompletionRatio map[string]float64 = nil var defaultCompletionRatio = map[string]float64{ @@ -191,11 +198,18 @@ var defaultCompletionRatio = map[string]float64{ "gpt-4-all": 2, } -func ModelPrice2JSONString() string { - if modelPrice == nil { - modelPrice = defaultModelPrice +func GetModelPriceMap() map[string]float64 { + modelPriceMapMutex.Lock() + defer modelPriceMapMutex.Unlock() + if modelPriceMap == nil { + modelPriceMap = defaultModelPrice } - jsonBytes, err := json.Marshal(modelPrice) + return modelPriceMap +} + +func ModelPrice2JSONString() string { + GetModelPriceMap() + jsonBytes, err := json.Marshal(modelPriceMap) if err != nil { SysError("error marshalling model price: " + err.Error()) } @@ -203,19 +217,19 @@ func ModelPrice2JSONString() string { } func UpdateModelPriceByJSONString(jsonStr string) error { - modelPrice = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPrice) + modelPriceMapMutex.Lock() + defer modelPriceMapMutex.Unlock() + modelPriceMap = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &modelPriceMap) } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false func GetModelPrice(name string, printErr bool) (float64, bool) { - if modelPrice == nil { - modelPrice = defaultModelPrice - } + GetModelPriceMap() if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } - price, ok := modelPrice[name] + price, ok := modelPriceMap[name] if !ok { if printErr { SysError("model price not found: " + name) @@ -225,18 +239,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { return price, true } -func GetModelPriceMap() map[string]float64 { - if modelPrice == nil { - modelPrice = defaultModelPrice +func GetModelRatioMap() map[string]float64 { + modelRatioMapMutex.Lock() + defer modelRatioMapMutex.Unlock() + if modelRatioMap == nil { + modelRatioMap = defaultModelRatio } - return modelPrice + return modelRatioMap } func ModelRatio2JSONString() string { - if modelRatio == nil { - modelRatio = defaultModelRatio - } - jsonBytes, err := json.Marshal(modelRatio) + GetModelRatioMap() + jsonBytes, err := json.Marshal(modelRatioMap) if err != nil { SysError("error marshalling model ratio: " + err.Error()) } @@ -244,18 +258,18 @@ func ModelRatio2JSONString() string { } func UpdateModelRatioByJSONString(jsonStr string) error { - modelRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatio) + modelRatioMapMutex.Lock() + defer modelRatioMapMutex.Unlock() + modelRatioMap = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &modelRatioMap) } func GetModelRatio(name string) float64 { - if modelRatio == nil { - modelRatio = defaultModelRatio - } + GetModelRatioMap() if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } - ratio, ok := modelRatio[name] + ratio, ok := modelRatioMap[name] if !ok { SysError("model ratio not found: " + name) return 30