feat: update token encoder

This commit is contained in:
1808837298@qq.com 2024-05-30 23:15:06 +08:00
parent 4dd5233f49
commit ecdcb379fe

View File

@ -17,6 +17,7 @@ import (
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
common.SysLog("initializing token encoders") common.SysLog("initializing token encoders")
@ -29,7 +30,7 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o") cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
} }
@ -38,7 +39,7 @@ func InitTokenEncoders() {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") { if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder tokenEncoderMap[model] = cl200kTokenEncoder
} else { } else {
tokenEncoderMap[model] = gpt4TokenEncoder tokenEncoderMap[model] = gpt4TokenEncoder
} }
@ -49,21 +50,30 @@ func InitTokenEncoders() {
common.SysLog("token encoders initialized") common.SysLog("token encoders initialized")
} }
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") {
return cl200kTokenEncoder
}
return defaultTokenEncoder
}
func getTokenEncoder(model string) *tiktoken.Tiktoken { func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model] tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil { if ok && tokenEncoder != nil {
return tokenEncoder return tokenEncoder
} }
// 如果ok即model在tokenEncoderMap中但是tokenEncoder为nil说明可能是自定义模型
if ok { if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model) tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder tokenEncoder = getModelDefaultTokenEncoder(model)
} }
tokenEncoderMap[model] = tokenEncoder tokenEncoderMap[model] = tokenEncoder
return tokenEncoder return tokenEncoder
} }
return defaultTokenEncoder // 如果model不在tokenEncoderMap中直接返回默认的tokenEncoder
return getModelDefaultTokenEncoder(model)
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
@ -75,13 +85,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if model == "glm-4v" { if model == "glm-4v" {
return 1047, nil return 1047, nil
} }
if imageUrl.Detail == "low" {
return 85, nil
}
// 同步One API的图片计费逻辑 // 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" { if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high" imageUrl.Detail = "high"
} }
if imageUrl.Detail == "low" {
return 85, nil
}
var config image.Config var config image.Config
var err error var err error
var format string var format string