feat: unify error messages

CaIon 2024-03-20 20:36:55 +08:00
parent eb6257a8d8
commit a232afe9fd
8 changed files with 35 additions and 25 deletions
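
In short: every token-counting helper in the service package now returns a trailing bool that is true only when the error came from the sensitive-word check, so callers can tell a content rejection apart from an internal counting failure. The signature change, taken from the hunks below (sketch, not the full file):

    // before: callers could not tell why counting failed
    func CountTokenText(text string, model string, check bool) (int, error)

    // after: the trailing bool reports a sensitive-word hit
    func CountTokenText(text string, model string, check bool) (int, error, bool)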

@@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

@@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

@@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	completionTokens := 0
 	for _, choice := range textResponse.Choices {
 		stringContent := string(choice.Message.Content)
-		ctkm, _ := service.CountTokenText(stringContent, model, false)
+		ctkm, _, _ := service.CountTokenText(stringContent, model, false)
 		completionTokens += ctkm
 		if checkSensitive {
 			sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)

@@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

@@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
+		promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 		quota = promptTokens
 	} else {
-		quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+		quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
 	}
 	quota = int(float64(quota) * ratio)
 	if ratio != 0 && quota <= 0 {
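
Note that the audio path discards the trigger with the blank identifier and keeps its generic count_audio_token_failed error; only the text path below maps the trigger to HTTP 400. For tts models the helper itself now screens the input while still billing by rune count; roughly, the contract is (a sketch; the input literal is illustrative):

    // Clean input: rune count, nil error, trigger=false.
    // Flagged input: the rune count is still returned alongside an
    // error naming the offending words, with trigger=true.
    n, err, trigger := service.CountAudioToken("hello world", "tts-1", true)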

@@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	var ratio float64
 	var modelRatio float64
 	//err := service.SensitiveWordsCheck(textRequest)
-	promptTokens, err := getPromptTokens(textRequest, relayInfo)
+	promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
 	// count messages token error 计算promptTokens错误
 	if err != nil {
+		if sensitiveTrigger {
+			return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+		}
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
@@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	return nil
 }
 
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
 	var promptTokens int
 	var err error
+	var sensitiveTrigger bool
 	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0
 	}
 	info.PromptTokens = promptTokens
-	return promptTokens, err
+	return promptTokens, err, sensitiveTrigger
 }
 
 // 预扣费并返回用户剩余配额
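
getPromptTokens now threads the trigger through all four relay modes (the default branch reports trigger=false), and TextHelper uses it to pick the status code: a sensitive-word hit becomes a 400 with the unified sensitive_words_detected id, while a genuine counting failure stays a 500. Condensed, the dispatch the hunks above implement is:

    promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
    if err != nil {
        if sensitiveTrigger {
            // the user's prompt was rejected, not our counter: 400
            return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
        }
        // token counting itself failed: 500
        return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
    }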

@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 		if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
 			var stringContent string
 			if err := json.Unmarshal(message.Content, &stringContent); err != nil {
-				return 0, err
+				return 0, err, false
 			} else {
 				if checkSensitive {
 					contains, words := SensitiveWordContains(stringContent)
 					if contains {
 						err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
-						return 0, err
+						return 0, err, true
 					}
 				}
 				tokenNum += getTokenNum(tokenEncoder, stringContent)
@@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 					imageTokenNum, err = getImageToken(&imageUrl)
 				}
 				if err != nil {
-					return 0, err
+					return 0, err, false
 				}
 			}
 			tokenNum += imageTokenNum
@@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum, nil
+	return tokenNum, nil, false
 }
 
-func CountTokenInput(input any, model string, check bool) (int, error) {
+func CountTokenInput(input any, model string, check bool) (int, error, bool) {
 	switch v := input.(type) {
 	case string:
 		return CountTokenText(v, model, check)
@@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
 		}
 		return CountTokenText(text, model, check)
 	}
-	return 0, errors.New("unsupported input type")
+	return 0, errors.New("unsupported input type"), false
 }
 
-func CountAudioToken(text string, model string, check bool) (int, error) {
+func CountAudioToken(text string, model string, check bool) (int, error, bool) {
 	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text), nil
+		contains, words := SensitiveWordContains(text)
+		if contains {
+			return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
+		}
+		return utf8.RuneCountInString(text), nil, false
 	} else {
 		return CountTokenText(text, model, check)
 	}
 }
 
 // CountTokenText 统计文本的token数量，仅当文本包含敏感词，返回错误，同时返回token数量
-func CountTokenText(text string, model string, check bool) (int, error) {
+func CountTokenText(text string, model string, check bool) (int, error, bool) {
 	var err error
+	var trigger bool
 	if check {
 		contains, words := SensitiveWordContains(text)
 		if contains {
 			err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
+			trigger = true
 		}
 	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err
+	return getTokenNum(tokenEncoder, text), err, trigger
 }
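
Callers that only want a count discard the trigger, as the handler hunks above do; callers that enforce policy check it. A minimal sketch of both conventions, assuming this commit's service package:

    // count only: ignore the trigger (completion-billing handlers do this)
    tokens, err, _ := service.CountTokenText(text, model, false)

    // enforce: distinguish a sensitive-word hit from a counting failure
    tokens, err, trigger := service.CountTokenText(text, model, true)
    if err != nil && trigger {
        // reject the request with a 400 rather than a generic 500
    }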

@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTokenText(responseText, modeName, false)
+	ctkm, err, _ := CountTokenText(responseText, modeName, false)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err
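
ResponseText2Usage passes check=false and drops the trigger, so usage accounting for streamed responses never trips the sensitive-word check. A typical call from a streaming handler might look like this (sketch; variable names are assumed, and the function is taken to live in the service package like the counters it calls):

    usage, err := service.ResponseText2Usage(responseText, modelName, promptTokens)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
    }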