diff --git a/service/token_counter.go b/service/token_counter.go index cdca1fd..b99fc20 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -17,6 +17,7 @@ import ( // tokenEncoderMap won't grow after initialization var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var defaultTokenEncoder *tiktoken.Tiktoken +var cl200kTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { common.SysLog("initializing token encoders") @@ -29,7 +30,7 @@ func InitTokenEncoders() { if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o") + cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o") if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) } @@ -38,7 +39,7 @@ func InitTokenEncoders() { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { if strings.HasPrefix(model, "gpt-4o") { - tokenEncoderMap[model] = gpt4oTokenEncoder + tokenEncoderMap[model] = cl200kTokenEncoder } else { tokenEncoderMap[model] = gpt4TokenEncoder } @@ -49,21 +50,30 @@ func InitTokenEncoders() { common.SysLog("token encoders initialized") } +func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken { + if strings.HasPrefix(model, "gpt-4o") { + return cl200kTokenEncoder + } + return defaultTokenEncoder +} + func getTokenEncoder(model string) *tiktoken.Tiktoken { tokenEncoder, ok := tokenEncoderMap[model] if ok && tokenEncoder != nil { return tokenEncoder } + // 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型 if ok { tokenEncoder, err := tiktoken.EncodingForModel(model) if err != nil { common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) - tokenEncoder = defaultTokenEncoder + tokenEncoder = getModelDefaultTokenEncoder(model) } tokenEncoderMap[model] = tokenEncoder return tokenEncoder } - return defaultTokenEncoder + // 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder + return getModelDefaultTokenEncoder(model) } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { @@ -75,13 +85,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in if model == "glm-4v" { return 1047, nil } + if imageUrl.Detail == "low" { + return 85, nil + } // 同步One API的图片计费逻辑 if imageUrl.Detail == "auto" || imageUrl.Detail == "" { imageUrl.Detail = "high" } - if imageUrl.Detail == "low" { - return 85, nil - } var config image.Config var err error var format string