diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1027faa..a56c1bb 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b199178..31badd8 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9ec4260..7e36861 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model completionTokens := 0 for _, choice := range textResponse.Choices { stringContent := string(choice.Message.Content) - ctkm, _ := service.CountTokenText(stringContent, model, false) + ctkm, _, _ := service.CountTokenText(stringContent, model, false) completionTokens += ctkm if checkSensitive { sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index b3607c0..4028269 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/relay-audio.go b/relay/relay-audio.go index d68550e..1c0f868 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if strings.HasPrefix(audioRequest.Model, "tts-1") { - promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) + promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } @@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if strings.HasPrefix(audioRequest.Model, "tts-1") { quota = promptTokens } else { - quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) + quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { diff --git a/relay/relay-text.go b/relay/relay-text.go index b9c1e7a..a8ba0e8 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { var ratio float64 var modelRatio float64 //err := service.SensitiveWordsCheck(textRequest) - promptTokens, err := getPromptTokens(textRequest, relayInfo) + promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { + if sensitiveTrigger { + return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + } return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } @@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { +func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) { var promptTokens int var err error + var sensitiveTrigger bool checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) default: err = errors.New("unknown relay mode") promptTokens = 0 } info.PromptTokens = promptTokens - return promptTokens, err + return promptTokens, err, sensitiveTrigger } // 预扣费并返回用户剩余配额 diff --git a/service/token_counter.go b/service/token_counter.go index a04be59..4769dab 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { return tiles*170 + 85, nil } -func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) { +func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo if err := json.Unmarshal(message.Content, &arrayContent); err != nil { var stringContent string if err := json.Unmarshal(message.Content, &stringContent); err != nil { - return 0, err + return 0, err, false } else { if checkSensitive { contains, words := SensitiveWordContains(stringContent) if contains { err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) - return 0, err + return 0, err, true } } tokenNum += getTokenNum(tokenEncoder, stringContent) @@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo imageTokenNum, err = getImageToken(&imageUrl) } if err != nil { - return 0, err + return 0, err, false } } tokenNum += imageTokenNum @@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil + return tokenNum, nil, false } -func CountTokenInput(input any, model string, check bool) (int, error) { +func CountTokenInput(input any, model string, check bool) (int, error, bool) { switch v := input.(type) { case string: return CountTokenText(v, model, check) @@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) { } return CountTokenText(text, model, check) } - return 0, errors.New("unsupported input type") + return 0, errors.New("unsupported input type"), false } -func CountAudioToken(text string, model string, check bool) (int, error) { +func CountAudioToken(text string, model string, check bool) (int, error, bool) { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text), nil + contains, words := SensitiveWordContains(text) + if contains { + return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true + } + return utf8.RuneCountInString(text), nil, false } else { return CountTokenText(text, model, check) } } // CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTokenText(text string, model string, check bool) (int, error) { +func CountTokenText(text string, model string, check bool) (int, error, bool) { var err error + var trigger bool if check { contains, words := SensitiveWordContains(text) if contains { err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")) + trigger = true } } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err + return getTokenNum(tokenEncoder, text), err, trigger } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 53a5c04..460ac56 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -19,7 +19,7 @@ import ( func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTokenText(responseText, modeName, false) + ctkm, err, _ := CountTokenText(responseText, modeName, false) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err