diff --git a/service/token_counter.go b/service/token_counter.go index 697ee15..bca3d51 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -67,7 +67,11 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { +func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { + // TODO: 非流模式下不计算图片token数量 + if model == "glm-4v" { + return 1047, nil + } if imageUrl.Detail == "low" { return 85, nil } @@ -123,7 +127,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) { tkm := 0 - msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive) + msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive) if err != nil { return 0, err, b } @@ -156,7 +160,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check return tkm, nil, false } -func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { +func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -193,19 +197,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo tokenNum += getTokenNum(tokenEncoder, *message.Name) } } else { - var err error arrayContent := message.ParseContent() for _, m := range arrayContent { if m.Type == "image_url" { - var imageTokenNum int - if model == "glm-4v" { - imageTokenNum = 1047 - } else { - imageUrl := m.ImageUrl.(dto.MessageImageUrl) - imageTokenNum, err = getImageToken(&imageUrl) - if err != nil { - return 0, err, false - } + imageUrl := m.ImageUrl.(dto.MessageImageUrl) + imageTokenNum, err := getImageToken(&imageUrl, model, stream) + if err != nil { + return 0, err, false } tokenNum += imageTokenNum log.Printf("image token num: %d", imageTokenNum)