diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 177d853..1873cab 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -2,10 +2,18 @@ package controller import ( "encoding/json" + "errors" "fmt" + "github.com/chai2010/webp" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" + "log" + "math" "net/http" "one-api/common" "strconv" @@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } +func getImageToken(imageUrl MessageImageUrl) (int, error) { + if imageUrl.Detail == "low" { + return 85, nil + } + + response, err := http.Get(imageUrl.Url) + if err != nil { + fmt.Println("Error: Failed to get the URL") + return 0, err + } + + defer response.Body.Close() + + // 限制读取的字节数,防止下载整个图片 + limitReader := io.LimitReader(response.Body, 8192) + + // 读取图片的头部信息来获取图片尺寸 + config, _, err := image.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) + config, err = webp.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) + } + } + if config.Width == 0 || config.Height == 0 { + return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error())) + } + if config.Width < 512 && config.Height < 512 { + if imageUrl.Detail == "auto" || imageUrl.Detail == "" { + return 85, nil + } + } + + shortSide := config.Width + otherSide := config.Height + log.Printf("width: %d, height: %d", config.Width, config.Height) + // 缩放倍数 + scale := 1.0 + if config.Height < shortSide { + shortSide = config.Height + otherSide = config.Width + } + + // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 + if shortSide > 768 { + scale = float64(shortSide) / 768 + shortSide = 768 + } + // 将另一边按照相同的比例缩小,向上取整 + otherSide = int(math.Ceil(float64(otherSide) / scale)) + log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) + // 计算图片的token数量(边的长度除以512,向上取整) + tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) + log.Printf("tiles: %d", tiles) + return tiles*170 + 85, nil +} + func countTokenMessages(messages []Message, model string) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) @@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) { } else { for _, m := range arrayContent { if m.Type == "image_url" { - //TODO: getImageToken - tokenNum += 1000 + imageTokenNum, err := getImageToken(m.ImageUrl) + if err != nil { + return 0, err + } + tokenNum += imageTokenNum + log.Printf("image token num: %d", imageTokenNum) } else { tokenNum += getTokenNum(tokenEncoder, m.Text) } diff --git a/go.mod b/go.mod index a82121b..3a75341 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/chai2010/webp v1.1.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 diff --git a/go.sum b/go.sum index 2d64620..6e7f963 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk= +github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=