mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-12 00:53:41 +08:00
feat: 初步兼容敏感词过滤
This commit is contained in:
@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
||||
return tiles*170 + 85, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string) (int, error) {
|
||||
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
@@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
|
||||
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
if checkSensitive {
|
||||
contains, words := SensitiveWordContains(stringContent)
|
||||
if contains {
|
||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
@@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
func CountTokenInput(input any, model string, check bool) (int, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTokenText(v, model)
|
||||
return CountTokenText(v, model, check)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTokenText(text, model)
|
||||
return CountTokenText(text, model, check)
|
||||
}
|
||||
return 0
|
||||
return 0, errors.New("unsupported input type")
|
||||
}
|
||||
|
||||
func CountAudioToken(text string, model string) int {
|
||||
func CountAudioToken(text string, model string, check bool) (int, error) {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text)
|
||||
return utf8.RuneCountInString(text), nil
|
||||
} else {
|
||||
return CountTokenText(text, model)
|
||||
return CountTokenText(text, model, check)
|
||||
}
|
||||
}
|
||||
|
||||
func CountTokenText(text string, model string) int {
|
||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTokenText(text string, model string, check bool) (int, error) {
|
||||
var err error
|
||||
if check {
|
||||
contains, words := SensitiveWordContains(text)
|
||||
if contains {
|
||||
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
||||
}
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user