Merge remote-tracking branch 'upstream/main'

This commit is contained in:
wozulong
2024-05-19 16:03:53 +08:00
12 changed files with 120 additions and 39 deletions

View File

@@ -69,7 +69,11 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
// TODO: 非流模式下不计算图片token数量
if model == "glm-4v" {
return 1047, nil
}
if imageUrl.Detail == "low" {
return 85, nil
}
@@ -125,7 +129,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
tkm := 0
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
if err != nil {
return 0, err, b
}
@@ -158,7 +162,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
return tkm, nil, false
}
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -195,19 +199,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
var err error
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if model == "glm-4v" {
imageTokenNum = 1047
} else {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err = getImageToken(&imageUrl)
if err != nil {
return 0, err, false
}
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
if err != nil {
return 0, err, false
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)