mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	fix: panic when get model ratio (close #392)
This commit is contained in:
		@@ -3,6 +3,7 @@ package common
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// from songquanpeng/one-api
 | 
			
		||||
@@ -182,8 +183,14 @@ var defaultModelPrice = map[string]float64{
 | 
			
		||||
	"swap_face":         0.05,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var modelPrice map[string]float64 = nil
 | 
			
		||||
var modelRatio map[string]float64 = nil
 | 
			
		||||
var (
 | 
			
		||||
	modelPriceMap      = make(map[string]float64)
 | 
			
		||||
	modelPriceMapMutex = sync.RWMutex{}
 | 
			
		||||
)
 | 
			
		||||
var (
 | 
			
		||||
	modelRatioMap      map[string]float64 = nil
 | 
			
		||||
	modelRatioMapMutex                    = sync.RWMutex{}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var CompletionRatio map[string]float64 = nil
 | 
			
		||||
var defaultCompletionRatio = map[string]float64{
 | 
			
		||||
@@ -191,11 +198,18 @@ var defaultCompletionRatio = map[string]float64{
 | 
			
		||||
	"gpt-4-all":     2,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelPrice2JSONString() string {
 | 
			
		||||
	if modelPrice == nil {
 | 
			
		||||
		modelPrice = defaultModelPrice
 | 
			
		||||
func GetModelPriceMap() map[string]float64 {
 | 
			
		||||
	modelPriceMapMutex.Lock()
 | 
			
		||||
	defer modelPriceMapMutex.Unlock()
 | 
			
		||||
	if modelPriceMap == nil {
 | 
			
		||||
		modelPriceMap = defaultModelPrice
 | 
			
		||||
	}
 | 
			
		||||
	jsonBytes, err := json.Marshal(modelPrice)
 | 
			
		||||
	return modelPriceMap
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelPrice2JSONString() string {
 | 
			
		||||
	GetModelPriceMap()
 | 
			
		||||
	jsonBytes, err := json.Marshal(modelPriceMap)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		SysError("error marshalling model price: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
@@ -203,19 +217,19 @@ func ModelPrice2JSONString() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateModelPriceByJSONString(jsonStr string) error {
 | 
			
		||||
	modelPrice = make(map[string]float64)
 | 
			
		||||
	return json.Unmarshal([]byte(jsonStr), &modelPrice)
 | 
			
		||||
	modelPriceMapMutex.Lock()
 | 
			
		||||
	defer modelPriceMapMutex.Unlock()
 | 
			
		||||
	modelPriceMap = make(map[string]float64)
 | 
			
		||||
	return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 | 
			
		||||
func GetModelPrice(name string, printErr bool) (float64, bool) {
 | 
			
		||||
	if modelPrice == nil {
 | 
			
		||||
		modelPrice = defaultModelPrice
 | 
			
		||||
	}
 | 
			
		||||
	GetModelPriceMap()
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4-gizmo") {
 | 
			
		||||
		name = "gpt-4-gizmo-*"
 | 
			
		||||
	}
 | 
			
		||||
	price, ok := modelPrice[name]
 | 
			
		||||
	price, ok := modelPriceMap[name]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		if printErr {
 | 
			
		||||
			SysError("model price not found: " + name)
 | 
			
		||||
@@ -225,18 +239,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 | 
			
		||||
	return price, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetModelPriceMap() map[string]float64 {
 | 
			
		||||
	if modelPrice == nil {
 | 
			
		||||
		modelPrice = defaultModelPrice
 | 
			
		||||
func GetModelRatioMap() map[string]float64 {
 | 
			
		||||
	modelRatioMapMutex.Lock()
 | 
			
		||||
	defer modelRatioMapMutex.Unlock()
 | 
			
		||||
	if modelRatioMap == nil {
 | 
			
		||||
		modelRatioMap = defaultModelRatio
 | 
			
		||||
	}
 | 
			
		||||
	return modelPrice
 | 
			
		||||
	return modelRatioMap
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
	if modelRatio == nil {
 | 
			
		||||
		modelRatio = defaultModelRatio
 | 
			
		||||
	}
 | 
			
		||||
	jsonBytes, err := json.Marshal(modelRatio)
 | 
			
		||||
	GetModelRatioMap()
 | 
			
		||||
	jsonBytes, err := json.Marshal(modelRatioMap)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		SysError("error marshalling model ratio: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
@@ -244,18 +258,18 @@ func ModelRatio2JSONString() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateModelRatioByJSONString(jsonStr string) error {
 | 
			
		||||
	modelRatio = make(map[string]float64)
 | 
			
		||||
	return json.Unmarshal([]byte(jsonStr), &modelRatio)
 | 
			
		||||
	modelRatioMapMutex.Lock()
 | 
			
		||||
	defer modelRatioMapMutex.Unlock()
 | 
			
		||||
	modelRatioMap = make(map[string]float64)
 | 
			
		||||
	return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetModelRatio(name string) float64 {
 | 
			
		||||
	if modelRatio == nil {
 | 
			
		||||
		modelRatio = defaultModelRatio
 | 
			
		||||
	}
 | 
			
		||||
	GetModelRatioMap()
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4-gizmo") {
 | 
			
		||||
		name = "gpt-4-gizmo-*"
 | 
			
		||||
	}
 | 
			
		||||
	ratio, ok := modelRatio[name]
 | 
			
		||||
	ratio, ok := modelRatioMap[name]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		SysError("model ratio not found: " + name)
 | 
			
		||||
		return 30
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user