mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-18 00:16:37 +08:00
feat: 统一错误提示
This commit is contained in:
parent
eb6257a8d8
commit
a232afe9fd
@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
|
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||||
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
|
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range textResponse.Choices {
|
for _, choice := range textResponse.Choices {
|
||||||
stringContent := string(choice.Message.Content)
|
stringContent := string(choice.Message.Content)
|
||||||
ctkm, _ := service.CountTokenText(stringContent, model, false)
|
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
if checkSensitive {
|
if checkSensitive {
|
||||||
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
||||||
|
@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
|
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||||
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
|
promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||||
quota = promptTokens
|
quota = promptTokens
|
||||||
} else {
|
} else {
|
||||||
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
|
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
|
||||||
}
|
}
|
||||||
quota = int(float64(quota) * ratio)
|
quota = int(float64(quota) * ratio)
|
||||||
if ratio != 0 && quota <= 0 {
|
if ratio != 0 && quota <= 0 {
|
||||||
|
@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
var ratio float64
|
var ratio float64
|
||||||
var modelRatio float64
|
var modelRatio float64
|
||||||
//err := service.SensitiveWordsCheck(textRequest)
|
//err := service.SensitiveWordsCheck(textRequest)
|
||||||
promptTokens, err := getPromptTokens(textRequest, relayInfo)
|
promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
|
||||||
|
|
||||||
// count messages token error 计算promptTokens错误
|
// count messages token error 计算promptTokens错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if sensitiveTrigger {
|
||||||
|
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
|
}
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
|
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var err error
|
var err error
|
||||||
|
var sensitiveTrigger bool
|
||||||
checkSensitive := constant.ShouldCheckPromptSensitive()
|
checkSensitive := constant.ShouldCheckPromptSensitive()
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown relay mode")
|
err = errors.New("unknown relay mode")
|
||||||
promptTokens = 0
|
promptTokens = 0
|
||||||
}
|
}
|
||||||
info.PromptTokens = promptTokens
|
info.PromptTokens = promptTokens
|
||||||
return promptTokens, err
|
return promptTokens, err, sensitiveTrigger
|
||||||
}
|
}
|
||||||
|
|
||||||
// 预扣费并返回用户剩余配额
|
// 预扣费并返回用户剩余配额
|
||||||
|
@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
|||||||
return tiles*170 + 85, nil
|
return tiles*170 + 85, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
|
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
||||||
//recover when panic
|
//recover when panic
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
// Reference:
|
// Reference:
|
||||||
@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|||||||
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
|
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
|
||||||
var stringContent string
|
var stringContent string
|
||||||
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
||||||
return 0, err
|
return 0, err, false
|
||||||
} else {
|
} else {
|
||||||
if checkSensitive {
|
if checkSensitive {
|
||||||
contains, words := SensitiveWordContains(stringContent)
|
contains, words := SensitiveWordContains(stringContent)
|
||||||
if contains {
|
if contains {
|
||||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||||
return 0, err
|
return 0, err, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||||
@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|||||||
imageTokenNum, err = getImageToken(&imageUrl)
|
imageTokenNum, err = getImageToken(&imageUrl)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokenNum += imageTokenNum
|
tokenNum += imageTokenNum
|
||||||
@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
return tokenNum, nil
|
return tokenNum, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenInput(input any, model string, check bool) (int, error) {
|
func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
return CountTokenText(v, model, check)
|
return CountTokenText(v, model, check)
|
||||||
@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
|
|||||||
}
|
}
|
||||||
return CountTokenText(text, model, check)
|
return CountTokenText(text, model, check)
|
||||||
}
|
}
|
||||||
return 0, errors.New("unsupported input type")
|
return 0, errors.New("unsupported input type"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountAudioToken(text string, model string, check bool) (int, error) {
|
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
return utf8.RuneCountInString(text), nil
|
contains, words := SensitiveWordContains(text)
|
||||||
|
if contains {
|
||||||
|
return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
|
||||||
|
}
|
||||||
|
return utf8.RuneCountInString(text), nil, false
|
||||||
} else {
|
} else {
|
||||||
return CountTokenText(text, model, check)
|
return CountTokenText(text, model, check)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||||
func CountTokenText(text string, model string, check bool) (int, error) {
|
func CountTokenText(text string, model string, check bool) (int, error, bool) {
|
||||||
var err error
|
var err error
|
||||||
|
var trigger bool
|
||||||
if check {
|
if check {
|
||||||
contains, words := SensitiveWordContains(text)
|
contains, words := SensitiveWordContains(text)
|
||||||
if contains {
|
if contains {
|
||||||
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
||||||
|
trigger = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
return getTokenNum(tokenEncoder, text), err
|
return getTokenNum(tokenEncoder, text), err, trigger
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ import (
|
|||||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = promptTokens
|
||||||
ctkm, err := CountTokenText(responseText, modeName, false)
|
ctkm, err, _ := CountTokenText(responseText, modeName, false)
|
||||||
usage.CompletionTokens = ctkm
|
usage.CompletionTokens = ctkm
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return usage, err
|
return usage, err
|
||||||
|
Loading…
Reference in New Issue
Block a user