chore: token counter

This commit is contained in:
CaIon 2024-05-18 15:14:49 +08:00
parent de81eba90b
commit a3de309175

View File

@ -67,7 +67,11 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))
} }
func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
// TODO: 非流模式下不计算图片token数量
if model == "glm-4v" {
return 1047, nil
}
if imageUrl.Detail == "low" { if imageUrl.Detail == "low" {
return 85, nil return 85, nil
} }
@ -123,7 +127,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) { func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
tkm := 0 tkm := 0
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive) msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
if err != nil { if err != nil {
return 0, err, b return 0, err, b
} }
@ -156,7 +160,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
return tkm, nil, false return tkm, nil, false
} }
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
//recover when panic //recover when panic
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
// Reference: // Reference:
@ -193,19 +197,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
tokenNum += getTokenNum(tokenEncoder, *message.Name) tokenNum += getTokenNum(tokenEncoder, *message.Name)
} }
} else { } else {
var err error
arrayContent := message.ParseContent() arrayContent := message.ParseContent()
for _, m := range arrayContent { for _, m := range arrayContent {
if m.Type == "image_url" { if m.Type == "image_url" {
var imageTokenNum int imageUrl := m.ImageUrl.(dto.MessageImageUrl)
if model == "glm-4v" { imageTokenNum, err := getImageToken(&imageUrl, model, stream)
imageTokenNum = 1047 if err != nil {
} else { return 0, err, false
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err = getImageToken(&imageUrl)
if err != nil {
return 0, err, false
}
} }
tokenNum += imageTokenNum tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum) log.Printf("image token num: %d", imageTokenNum)