feat: 初步兼容敏感词过滤

This commit is contained in:
CaIon
2024-03-20 17:07:42 +08:00
parent bec21ade9d
commit 7a663d26ec
22 changed files with 293 additions and 66 deletions

View File

@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
func CountTokenMessages(messages []dto.Message, model string) (int, error) {
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
@@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
return tokenNum, nil
}
func CountTokenInput(input any, model string) int {
func CountTokenInput(input any, model string, check bool) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
return CountTokenText(v, model, check)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
return CountTokenText(text, model, check)
}
return 0
return 0, errors.New("unsupported input type")
}
func CountAudioToken(text string, model string) int {
func CountAudioToken(text string, model string, check bool) (int, error) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model)
return CountTokenText(text, model, check)
}
}
func CountTokenText(text string, model string) int {
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string, check bool) (int, error) {
var err error
if check {
contains, words := SensitiveWordContains(text)
if contains {
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
}
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
return getTokenNum(tokenEncoder, text), err
}