From 783e8fd74a06ef1fecd545edc59668b09cda694b Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Tue, 23 Apr 2024 23:51:27 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E8=AE=A1?= =?UTF-8?q?=E8=B4=B9=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/token_counter.go | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/service/token_counter.go b/service/token_counter.go index c1bac75..52828e9 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -128,7 +128,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check var openaiTools []dto.OpenAITools err := json.Unmarshal(toolsData, &openaiTools) if err != nil { - return 0, errors.New(fmt.Sprintf("count tools token fail: %s", err.Error())), false + return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false } countStr := "" for _, tool := range openaiTools { @@ -173,26 +173,23 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo tokenNum += tokensPerMessage tokenNum += getTokenNum(tokenEncoder, message.Role) if len(message.Content) > 0 { - var arrayContent []dto.MediaMessage - if err := json.Unmarshal(message.Content, &arrayContent); err != nil { - var stringContent string - if err := json.Unmarshal(message.Content, &stringContent); err != nil { - return 0, err, false - } else { - if checkSensitive { - contains, words := SensitiveWordContains(stringContent) - if contains { - err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) - return 0, err, true - } - } - tokenNum += getTokenNum(tokenEncoder, stringContent) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) + if message.IsStringContent() { + stringContent := message.StringContent() + if checkSensitive { + contains, words := SensitiveWordContains(stringContent) + if contains { + err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) + return 0, err, true } } + tokenNum += getTokenNum(tokenEncoder, stringContent) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } } else { + var err error + arrayContent := message.ParseContent() for _, m := range arrayContent { if m.Type == "image_url" { var imageTokenNum int