fix image token calculate

This commit is contained in:
CaIon 2023-11-17 20:32:11 +08:00
parent 7e0d2606c3
commit e5c2524f15
3 changed files with 75 additions and 2 deletions

View File

@ -2,10 +2,18 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/chai2010/webp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go" "github.com/pkoukk/tiktoken-go"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io" "io"
"log"
"math"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strconv" "strconv"
@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))
} }
func getImageToken(imageUrl MessageImageUrl) (int, error) {
if imageUrl.Detail == "low" {
return 85, nil
}
response, err := http.Get(imageUrl.Url)
if err != nil {
fmt.Println("Error: Failed to get the URL")
return 0, err
}
defer response.Body.Close()
// 限制读取的字节数,防止下载整个图片
limitReader := io.LimitReader(response.Body, 8192)
// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
config, err = webp.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
}
}
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
}
if config.Width < 512 && config.Height < 512 {
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
return 85, nil
}
}
shortSide := config.Width
otherSide := config.Height
log.Printf("width: %d, height: %d", config.Width, config.Height)
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
shortSide = config.Height
otherSide = config.Width
}
// 将最小变的尺寸缩小到768以下如果大于768则缩放到768
if shortSide > 768 {
scale = float64(shortSide) / 768
shortSide = 768
}
// 将另一边按照相同的比例缩小,向上取整
otherSide = int(math.Ceil(float64(otherSide) / scale))
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
// 计算图片的token数量(边的长度除以512向上取整)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles)
return tiles*170 + 85, nil
}
func countTokenMessages(messages []Message, model string) (int, error) { func countTokenMessages(messages []Message, model string) (int, error) {
//recover when panic //recover when panic
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
} else { } else {
for _, m := range arrayContent { for _, m := range arrayContent {
if m.Type == "image_url" { if m.Type == "image_url" {
//TODO: getImageToken imageTokenNum, err := getImageToken(m.ImageUrl)
tokenNum += 1000 if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else { } else {
tokenNum += getTokenNum(tokenEncoder, m.Text) tokenNum += getTokenNum(tokenEncoder, m.Text)
} }

1
go.mod
View File

@ -4,6 +4,7 @@ module one-api
go 1.18 go 1.18
require ( require (
github.com/chai2010/webp v1.1.1
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5

2
go.sum
View File

@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=