From d2a0d9f73b1fa395c78b141fb95af347066e49f8 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 30 May 2024 21:39:58 +0800 Subject: [PATCH] feat: update tiktoken --- go.mod | 2 +- go.sum | 4 ++-- model/cache.go | 2 +- service/token_counter.go | 27 +++++++++++++++++---------- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 81f487b..a9d4a1d 100644 --- a/go.mod +++ b/go.mod @@ -20,8 +20,8 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 github.com/jinzhu/copier v0.4.0 - github.com/linux-do/tiktoken-go v0.7.0 github.com/pkg/errors v0.9.1 + github.com/pkoukk/tiktoken-go v0.1.7 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible golang.org/x/crypto v0.21.0 diff --git a/go.sum b/go.sum index 49c71ef..a77a89c 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,6 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8= -github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -150,6 +148,8 @@ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNc github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= +github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= diff --git a/model/cache.go b/model/cache.go index 330e09b..2977bb6 100644 --- a/model/cache.go +++ b/model/cache.go @@ -87,7 +87,7 @@ func SyncTokenCache(frequency int) { } } else { // 如果数据库中存在,先检查redis - _, err := common.RedisGet(fmt.Sprintf("token:%s", key)) + _, err = common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { // 如果redis中不存在,则跳过 continue diff --git a/service/token_counter.go b/service/token_counter.go index e9dab63..c540ac5 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" "fmt" - "github.com/linux-do/tiktoken-go" + "github.com/pkoukk/tiktoken-go" "image" "log" "math" @@ -26,10 +26,13 @@ func InitTokenEncoders() { } defaultTokenEncoder = gpt35TokenEncoder gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") - gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o") if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } + gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) + } for model, _ := range common.GetDefaultModelRatioMap() { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder @@ -72,6 +75,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in if model == "glm-4v" { return 1047, nil } + // 同步One API的图片计费逻辑 + if imageUrl.Detail == "auto" || imageUrl.Detail == "" { + imageUrl.Detail = "high" + } if imageUrl.Detail == "low" { return 85, nil } @@ -92,14 +99,14 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in if config.Width == 0 || config.Height == 0 { return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) } - // TODO: 适配官方auto计费 - if config.Width < 512 && config.Height < 512 { - if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - // 如果图片尺寸小于512,强制使用low - imageUrl.Detail = "low" - return 85, nil - } - } + //// TODO: 适配官方auto计费 + //if config.Width < 512 && config.Height < 512 { + // if imageUrl.Detail == "auto" || imageUrl.Detail == "" { + // // 如果图片尺寸小于512,强制使用low + // imageUrl.Detail = "low" + // return 85, nil + // } + //} shortSide := config.Width otherSide := config.Height