feat: 统一错误提示

This commit is contained in:
CaIon
2024-03-20 20:36:55 +08:00
parent eb6257a8d8
commit a232afe9fd
8 changed files with 35 additions and 25 deletions

View File

@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err
return 0, err, false
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
@@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
return 0, err, false
}
}
tokenNum += imageTokenNum
@@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil
return tokenNum, nil, false
}
func CountTokenInput(input any, model string, check bool) (int, error) {
func CountTokenInput(input any, model string, check bool) (int, error, bool) {
switch v := input.(type) {
case string:
return CountTokenText(v, model, check)
@@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
}
return CountTokenText(text, model, check)
}
return 0, errors.New("unsupported input type")
return 0, errors.New("unsupported input type"), false
}
func CountAudioToken(text string, model string, check bool) (int, error) {
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
contains, words := SensitiveWordContains(text)
if contains {
return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
}
return utf8.RuneCountInString(text), nil, false
} else {
return CountTokenText(text, model, check)
}
}
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string, check bool) (int, error) {
func CountTokenText(text string, model string, check bool) (int, error, bool) {
var err error
var trigger bool
if check {
contains, words := SensitiveWordContains(text)
if contains {
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
trigger = true
}
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err
return getTokenNum(tokenEncoder, text), err, trigger
}