mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 07:56:38 +08:00
feat: update token encoder
This commit is contained in:
parent
4dd5233f49
commit
ecdcb379fe
@ -17,6 +17,7 @@ import (
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
var cl200kTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
@ -29,7 +30,7 @@ func InitTokenEncoders() {
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
|
||||
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||
}
|
||||
@ -38,7 +39,7 @@ func InitTokenEncoders() {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
tokenEncoderMap[model] = cl200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
}
|
||||
@ -49,21 +50,30 @@ func InitTokenEncoders() {
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
return cl200kTokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
return tokenEncoder
|
||||
}
|
||||
// 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
|
||||
if ok {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
tokenEncoder = getModelDefaultTokenEncoder(model)
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
// 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
|
||||
return getModelDefaultTokenEncoder(model)
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
@ -75,13 +85,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
if model == "glm-4v" {
|
||||
return 1047, nil
|
||||
}
|
||||
if imageUrl.Detail == "low" {
|
||||
return 85, nil
|
||||
}
|
||||
// 同步One API的图片计费逻辑
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
imageUrl.Detail = "high"
|
||||
}
|
||||
if imageUrl.Detail == "low" {
|
||||
return 85, nil
|
||||
}
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
|
Loading…
Reference in New Issue
Block a user