feat: 初步兼容敏感词过滤

This commit is contained in:
CaIon
2024-03-20 17:07:42 +08:00
parent bec21ade9d
commit 7a663d26ec
22 changed files with 293 additions and 66 deletions

60
service/sensitive.go Normal file
View File

@@ -0,0 +1,60 @@
package service
import (
"bytes"
"fmt"
"github.com/anknown/ahocorasick"
"one-api/constant"
"strings"
)
// SensitiveWordContains reports whether text contains any sensitive word.
// It returns true together with every matched word (one entry per hit, so a
// word that occurs several times is listed several times), or false and a nil
// slice when nothing matched.
func SensitiveWordContains(text string) (bool, []string) {
	// Build an Aho-Corasick automaton from the sensitive-word list.
	m := initAc()
	if m == nil {
		// initAc returns nil when the automaton could not be built; searching
		// through a nil machine would dereference a nil pointer, so treat the
		// text as clean instead.
		return false, nil
	}
	hits := m.MultiPatternSearch([]rune(text), false)
	if len(hits) == 0 {
		return false, nil
	}
	words := make([]string, 0, len(hits))
	for _, hit := range hits {
		words = append(words, string(hit.Word))
	}
	return true, words
}
// SensitiveWordReplace masks every sensitive word found in text with '*'
// characters, one '*' per character of the matched word. It returns whether
// any sensitive word was found and the resulting text; when nothing matched,
// text is returned unchanged.
func SensitiveWordReplace(text string) (bool, string) {
	m := initAc()
	if m == nil {
		// initAc returns nil when the automaton could not be built; searching
		// through a nil machine would dereference a nil pointer, so leave the
		// text untouched instead.
		return false, text
	}
	runes := []rune(text)
	hits := m.MultiPatternSearch(runes, false)
	if len(hits) == 0 {
		return false, text
	}
	for _, hit := range hits {
		// The search ran over the rune slice, so hit.Pos and len(hit.Word)
		// are rune offsets. Masking must therefore happen on the rune slice
		// too: using them as byte offsets into the UTF-8 string corrupts
		// multi-byte text.
		start, end := hit.Pos, hit.Pos+len(hit.Word)
		if start < 0 || end > len(runes) {
			continue
		}
		copy(runes[start:end], []rune(strings.Repeat("*", end-start)))
	}
	return true, string(runes)
}
// initAc builds a fresh Aho-Corasick machine from the dictionary produced by
// readRunes. If the build fails, the error is printed to standard output and
// nil is returned, so callers must be prepared for a nil machine.
func initAc() *goahocorasick.Machine {
	machine := &goahocorasick.Machine{}
	buildErr := machine.Build(readRunes())
	if buildErr == nil {
		return machine
	}
	fmt.Println(buildErr)
	return nil
}
// readRunes converts constant.SensitiveWords into the rune-slice dictionary
// consumed by the Aho-Corasick machine. Leading and trailing whitespace is
// stripped from each word before conversion.
func readRunes() [][]rune {
	var patterns [][]rune
	for i := range constant.SensitiveWords {
		trimmed := bytes.TrimSpace([]byte(constant.SensitiveWords[i]))
		patterns = append(patterns, bytes.Runes(trimmed))
	}
	return patterns
}

View File

@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
func CountTokenMessages(messages []dto.Message, model string) (int, error) {
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
@@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
return tokenNum, nil
}
func CountTokenInput(input any, model string) int {
func CountTokenInput(input any, model string, check bool) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
return CountTokenText(v, model, check)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
return CountTokenText(text, model, check)
}
return 0
return 0, errors.New("unsupported input type")
}
func CountAudioToken(text string, model string) int {
func CountAudioToken(text string, model string, check bool) (int, error) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model)
return CountTokenText(text, model, check)
}
}
func CountTokenText(text string, model string) int {
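// CountTokenText counts the tokens in text for the given model. A non-nil error is returned only when
// check is true and the text contains sensitive words; the token count is still returned in that case.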
func CountTokenText(text string, model string, check bool) (int, error) {
var err error
if check {
contains, words := SensitiveWordContains(text)
if contains {
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
}
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
return getTokenNum(tokenEncoder, text), err
}

View File

@@ -1,27 +1,26 @@
package service
import (
"errors"
"one-api/dto"
"one-api/relay/constant"
)
func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
switch relayMode {
case constant.RelayModeChatCompletions:
return CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions:
return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
case constant.RelayModeModerations:
return CountTokenInput(textRequest.Input, textRequest.Model), nil
}
return 0, errors.New("unknown relay mode")
}
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
// switch relayMode {
// case constant.RelayModeChatCompletions:
// return CountTokenMessages(textRequest.Messages, textRequest.Model)
// case constant.RelayModeCompletions:
// return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
// case constant.RelayModeModerations:
// return CountTokenInput(textRequest.Input, textRequest.Model), nil
// }
// return 0, errors.New("unknown relay mode")
//}
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
ctkm, err := CountTokenText(responseText, modeName, false)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
return usage, err
}