feat: update tiktoken

This commit is contained in:
1808837298@qq.com
2024-05-30 21:39:58 +08:00
parent a31247ecaa
commit d2a0d9f73b
4 changed files with 21 additions and 14 deletions

View File

@@ -4,7 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/linux-do/tiktoken-go"
"github.com/pkoukk/tiktoken-go"
"image"
"log"
"math"
@@ -26,10 +26,13 @@ func InitTokenEncoders() {
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
@@ -72,6 +75,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if model == "glm-4v" {
return 1047, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
}
if imageUrl.Detail == "low" {
return 85, nil
}
@@ -92,14 +99,14 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
}
// TODO: 适配官方auto计费
if config.Width < 512 && config.Height < 512 {
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
// 如果图片尺寸小于512强制使用low
imageUrl.Detail = "low"
return 85, nil
}
}
//// TODO: 适配官方auto计费
//if config.Width < 512 && config.Height < 512 {
// if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
// // 如果图片尺寸小于512强制使用low
// imageUrl.Detail = "low"
// return 85, nil
// }
//}
shortSide := config.Width
otherSide := config.Height