mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-13 01:23:41 +08:00
refactor: 重构敏感词
This commit is contained in:
@@ -125,11 +125,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
return tiles*170 + 85, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
|
||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
|
||||
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err, b
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
if request.Tools != nil {
|
||||
@@ -137,7 +137,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
|
||||
var openaiTools []dto.OpenAITools
|
||||
err := json.Unmarshal(toolsData, &openaiTools)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
|
||||
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
|
||||
}
|
||||
countStr := ""
|
||||
for _, tool := range openaiTools {
|
||||
@@ -149,18 +149,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err, _ := CountTokenInput(countStr, model, false)
|
||||
toolTokens, err := CountTokenInput(countStr, model)
|
||||
if err != nil {
|
||||
return 0, err, false
|
||||
return 0, err
|
||||
}
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
}
|
||||
|
||||
return tkm, nil, false
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
@@ -184,13 +184,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
||||
if len(message.Content) > 0 {
|
||||
if message.IsStringContent() {
|
||||
stringContent := message.StringContent()
|
||||
if checkSensitive {
|
||||
contains, words := SensitiveWordContains(stringContent)
|
||||
if contains {
|
||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||
return 0, err, true
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
@@ -203,7 +196,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||
if err != nil {
|
||||
return 0, err, false
|
||||
return 0, err
|
||||
}
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
@@ -215,33 +208,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum, nil, false
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTokenText(v, model, check)
|
||||
return CountTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTokenText(text, model, check)
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
|
||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
|
||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
|
||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
@@ -249,29 +242,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
||||
func CountAudioToken(text string, model string) (int, error) {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
contains, words := SensitiveWordContains(text)
|
||||
if contains {
|
||||
return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
|
||||
}
|
||||
return utf8.RuneCountInString(text), nil, false
|
||||
return utf8.RuneCountInString(text), nil
|
||||
} else {
|
||||
return CountTokenText(text, model, check)
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTokenText(text string, model string, check bool) (int, error, bool) {
|
||||
func CountTokenText(text string, model string) (int, error) {
|
||||
var err error
|
||||
var trigger bool
|
||||
if check {
|
||||
contains, words := SensitiveWordContains(text)
|
||||
if contains {
|
||||
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
||||
trigger = true
|
||||
}
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err, trigger
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user