mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-07 14:53:40 +08:00
refactor: 重构敏感词
This commit is contained in:
@@ -1,5 +1,13 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
goahocorasick "github.com/anknown/ahocorasick"
|
||||||
|
"one-api/constant"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
func SundaySearch(text string, pattern string) bool {
|
func SundaySearch(text string, pattern string) bool {
|
||||||
// 计算偏移表
|
// 计算偏移表
|
||||||
offset := make(map[rune]int)
|
offset := make(map[rune]int)
|
||||||
@@ -48,3 +56,25 @@ func RemoveDuplicate(s []string) []string {
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InitAc() *goahocorasick.Machine {
|
||||||
|
m := new(goahocorasick.Machine)
|
||||||
|
dict := readRunes()
|
||||||
|
if err := m.Build(dict); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRunes() [][]rune {
|
||||||
|
var dict [][]rune
|
||||||
|
|
||||||
|
for _, word := range constant.SensitiveWords {
|
||||||
|
word = strings.ToLower(word)
|
||||||
|
l := bytes.TrimSpace([]byte(word))
|
||||||
|
dict = append(dict, bytes.Runes(l))
|
||||||
|
}
|
||||||
|
|
||||||
|
return dict
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0
|
|||||||
// SensitiveWords 敏感词
|
// SensitiveWords 敏感词
|
||||||
// var SensitiveWords []string
|
// var SensitiveWords []string
|
||||||
var SensitiveWords = []string{
|
var SensitiveWords = []string{
|
||||||
"test",
|
"test_sensitive",
|
||||||
}
|
}
|
||||||
|
|
||||||
func SensitiveWordsToString() string {
|
func SensitiveWordsToString() string {
|
||||||
|
|||||||
@@ -370,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||||
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
|
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|||||||
if simpleResponse.Usage.TotalTokens == 0 {
|
if simpleResponse.Usage.TotalTokens == 0 {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
|
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
}
|
}
|
||||||
simpleResponse.Usage = dto.Usage{
|
simpleResponse.Usage = dto.Usage{
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
|
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -55,7 +55,13 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||||
promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
|
if constant.ShouldCheckPromptSensitive() {
|
||||||
|
err = service.CheckSensitiveInput(audioRequest.Input)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -178,7 +184,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||||
quota = promptTokens
|
quota = promptTokens
|
||||||
} else {
|
} else {
|
||||||
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
|
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
|
||||||
}
|
}
|
||||||
quota = int(float64(quota) * ratio)
|
quota = int(float64(quota) * ratio)
|
||||||
if ratio != 0 && quota <= 0 {
|
if ratio != 0 && quota <= 0 {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -47,6 +48,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constant.ShouldCheckPromptSensitive() {
|
||||||
|
err = service.CheckSensitiveInput(imageRequest.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if strings.Contains(imageRequest.Size, "×") {
|
if strings.Contains(imageRequest.Size, "×") {
|
||||||
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
var ratio float64
|
var ratio float64
|
||||||
var modelRatio float64
|
var modelRatio float64
|
||||||
//err := service.SensitiveWordsCheck(textRequest)
|
//err := service.SensitiveWordsCheck(textRequest)
|
||||||
promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
|
|
||||||
|
|
||||||
// count messages token error 计算promptTokens错误
|
if constant.ShouldCheckPromptSensitive() {
|
||||||
|
err = checkRequestSensitive(textRequest, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if sensitiveTrigger {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptTokens, err := getPromptTokens(textRequest, relayInfo)
|
||||||
|
// count messages token error 计算promptTokens错误
|
||||||
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,7 +132,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo, *textRequest)
|
adaptor.Init(relayInfo, *textRequest)
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
@@ -136,7 +140,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if isModelMapped {
|
if isModelMapped {
|
||||||
jsonStr, err := json.Marshal(textRequest)
|
jsonStr, err := json.Marshal(textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
} else {
|
} else {
|
||||||
@@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
} else {
|
} else {
|
||||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
@@ -182,26 +186,39 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
|
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var err error
|
var err error
|
||||||
var sensitiveTrigger bool
|
|
||||||
checkSensitive := constant.ShouldCheckPromptSensitive()
|
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
|
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
|
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown relay mode")
|
err = errors.New("unknown relay mode")
|
||||||
promptTokens = 0
|
promptTokens = 0
|
||||||
}
|
}
|
||||||
info.PromptTokens = promptTokens
|
info.PromptTokens = promptTokens
|
||||||
return promptTokens, err, sensitiveTrigger
|
return promptTokens, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
||||||
|
var err error
|
||||||
|
switch info.RelayMode {
|
||||||
|
case relayconstant.RelayModeChatCompletions:
|
||||||
|
err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||||
|
case relayconstant.RelayModeCompletions:
|
||||||
|
err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||||
|
case relayconstant.RelayModeModerations:
|
||||||
|
err = service.CheckSensitiveInput(textRequest.Input)
|
||||||
|
case relayconstant.RelayModeEmbeddings:
|
||||||
|
err = service.CheckSensitiveInput(textRequest.Input)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 预扣费并返回用户剩余配额
|
// 预扣费并返回用户剩余配额
|
||||||
|
|||||||
@@ -1,13 +1,60 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/anknown/ahocorasick"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func CheckSensitiveMessages(messages []dto.Message) error {
|
||||||
|
for _, message := range messages {
|
||||||
|
if len(message.Content) > 0 {
|
||||||
|
if message.IsStringContent() {
|
||||||
|
stringContent := message.StringContent()
|
||||||
|
if ok, words := SensitiveWordContains(stringContent); ok {
|
||||||
|
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
arrayContent := message.ParseContent()
|
||||||
|
for _, m := range arrayContent {
|
||||||
|
if m.Type == "image_url" {
|
||||||
|
// TODO: check image url
|
||||||
|
} else {
|
||||||
|
if ok, words := SensitiveWordContains(m.Text); ok {
|
||||||
|
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckSensitiveText(text string) error {
|
||||||
|
if ok, words := SensitiveWordContains(text); ok {
|
||||||
|
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckSensitiveInput(input any) error {
|
||||||
|
switch v := input.(type) {
|
||||||
|
case string:
|
||||||
|
return CheckSensitiveText(v)
|
||||||
|
case []string:
|
||||||
|
text := ""
|
||||||
|
for _, s := range v {
|
||||||
|
text += s
|
||||||
|
}
|
||||||
|
return CheckSensitiveText(text)
|
||||||
|
}
|
||||||
|
return CheckSensitiveText(fmt.Sprintf("%v", input))
|
||||||
|
}
|
||||||
|
|
||||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
||||||
func SensitiveWordContains(text string) (bool, []string) {
|
func SensitiveWordContains(text string) (bool, []string) {
|
||||||
if len(constant.SensitiveWords) == 0 {
|
if len(constant.SensitiveWords) == 0 {
|
||||||
@@ -15,7 +62,7 @@ func SensitiveWordContains(text string) (bool, []string) {
|
|||||||
}
|
}
|
||||||
checkText := strings.ToLower(text)
|
checkText := strings.ToLower(text)
|
||||||
// 构建一个AC自动机
|
// 构建一个AC自动机
|
||||||
m := initAc()
|
m := common.InitAc()
|
||||||
hits := m.MultiPatternSearch([]rune(checkText), false)
|
hits := m.MultiPatternSearch([]rune(checkText), false)
|
||||||
if len(hits) > 0 {
|
if len(hits) > 0 {
|
||||||
words := make([]string, 0)
|
words := make([]string, 0)
|
||||||
@@ -33,7 +80,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
|
|||||||
return false, nil, text
|
return false, nil, text
|
||||||
}
|
}
|
||||||
checkText := strings.ToLower(text)
|
checkText := strings.ToLower(text)
|
||||||
m := initAc()
|
m := common.InitAc()
|
||||||
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
|
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
|
||||||
if len(hits) > 0 {
|
if len(hits) > 0 {
|
||||||
words := make([]string, 0)
|
words := make([]string, 0)
|
||||||
@@ -47,25 +94,3 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
|
|||||||
}
|
}
|
||||||
return false, nil, text
|
return false, nil, text
|
||||||
}
|
}
|
||||||
|
|
||||||
func initAc() *goahocorasick.Machine {
|
|
||||||
m := new(goahocorasick.Machine)
|
|
||||||
dict := readRunes()
|
|
||||||
if err := m.Build(dict); err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func readRunes() [][]rune {
|
|
||||||
var dict [][]rune
|
|
||||||
|
|
||||||
for _, word := range constant.SensitiveWords {
|
|
||||||
word = strings.ToLower(word)
|
|
||||||
l := bytes.TrimSpace([]byte(word))
|
|
||||||
dict = append(dict, bytes.Runes(l))
|
|
||||||
}
|
|
||||||
|
|
||||||
return dict
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -125,11 +125,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
|||||||
return tiles*170 + 85, nil
|
return tiles*170 + 85, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
|
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
|
||||||
tkm := 0
|
tkm := 0
|
||||||
msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
|
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err, b
|
return 0, err
|
||||||
}
|
}
|
||||||
tkm += msgTokens
|
tkm += msgTokens
|
||||||
if request.Tools != nil {
|
if request.Tools != nil {
|
||||||
@@ -137,7 +137,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
|
|||||||
var openaiTools []dto.OpenAITools
|
var openaiTools []dto.OpenAITools
|
||||||
err := json.Unmarshal(toolsData, &openaiTools)
|
err := json.Unmarshal(toolsData, &openaiTools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
|
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
|
||||||
}
|
}
|
||||||
countStr := ""
|
countStr := ""
|
||||||
for _, tool := range openaiTools {
|
for _, tool := range openaiTools {
|
||||||
@@ -149,18 +149,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
|
|||||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
toolTokens, err, _ := CountTokenInput(countStr, model, false)
|
toolTokens, err := CountTokenInput(countStr, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err, false
|
return 0, err
|
||||||
}
|
}
|
||||||
tkm += 8
|
tkm += 8
|
||||||
tkm += toolTokens
|
tkm += toolTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
return tkm, nil, false
|
return tkm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
|
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||||
//recover when panic
|
//recover when panic
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
// Reference:
|
// Reference:
|
||||||
@@ -184,13 +184,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
|||||||
if len(message.Content) > 0 {
|
if len(message.Content) > 0 {
|
||||||
if message.IsStringContent() {
|
if message.IsStringContent() {
|
||||||
stringContent := message.StringContent()
|
stringContent := message.StringContent()
|
||||||
if checkSensitive {
|
|
||||||
contains, words := SensitiveWordContains(stringContent)
|
|
||||||
if contains {
|
|
||||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
|
||||||
return 0, err, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||||
if message.Name != nil {
|
if message.Name != nil {
|
||||||
tokenNum += tokensPerName
|
tokenNum += tokensPerName
|
||||||
@@ -203,7 +196,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
|||||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err, false
|
return 0, err
|
||||||
}
|
}
|
||||||
tokenNum += imageTokenNum
|
tokenNum += imageTokenNum
|
||||||
log.Printf("image token num: %d", imageTokenNum)
|
log.Printf("image token num: %d", imageTokenNum)
|
||||||
@@ -215,33 +208,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
return tokenNum, nil, false
|
return tokenNum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
func CountTokenInput(input any, model string) (int, error) {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
return CountTokenText(v, model, check)
|
return CountTokenText(v, model)
|
||||||
case []string:
|
case []string:
|
||||||
text := ""
|
text := ""
|
||||||
for _, s := range v {
|
for _, s := range v {
|
||||||
text += s
|
text += s
|
||||||
}
|
}
|
||||||
return CountTokenText(text, model, check)
|
return CountTokenText(text, model)
|
||||||
}
|
}
|
||||||
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||||
tokens := 0
|
tokens := 0
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
|
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
if message.Delta.ToolCalls != nil {
|
if message.Delta.ToolCalls != nil {
|
||||||
for _, tool := range message.Delta.ToolCalls {
|
for _, tool := range message.Delta.ToolCalls {
|
||||||
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
|
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
|
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -249,29 +242,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
|||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
func CountAudioToken(text string, model string) (int, error) {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
contains, words := SensitiveWordContains(text)
|
return utf8.RuneCountInString(text), nil
|
||||||
if contains {
|
|
||||||
return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
|
|
||||||
}
|
|
||||||
return utf8.RuneCountInString(text), nil, false
|
|
||||||
} else {
|
} else {
|
||||||
return CountTokenText(text, model, check)
|
return CountTokenText(text, model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
// CountTokenText counts the tokens in text for the given model. It no longer checks for sensitive words, and the returned error is always nil.
|
||||||
func CountTokenText(text string, model string, check bool) (int, error, bool) {
|
func CountTokenText(text string, model string) (int, error) {
|
||||||
var err error
|
var err error
|
||||||
var trigger bool
|
|
||||||
if check {
|
|
||||||
contains, words := SensitiveWordContains(text)
|
|
||||||
if contains {
|
|
||||||
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
|
||||||
trigger = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
return getTokenNum(tokenEncoder, text), err, trigger
|
return getTokenNum(tokenEncoder, text), err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = promptTokens
|
||||||
ctkm, err, _ := CountTokenText(responseText, modeName, false)
|
ctkm, err := CountTokenText(responseText, modeName)
|
||||||
usage.CompletionTokens = ctkm
|
usage.CompletionTokens = ctkm
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return usage, err
|
return usage, err
|
||||||
|
|||||||
Reference in New Issue
Block a user