From 7a663d26ec2b30cedc7c62378ff2ad8fdb135561 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 17:07:42 +0800 Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E6=95=8F=E6=84=9F=E8=AF=8D=E8=BF=87=E6=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 38 ++++++++++++++++ constant/sensitive.go | 32 ++++++++++++++ go.mod | 2 + go.sum | 20 +++------ model/option.go | 15 +++++++ relay/channel/claude/relay-claude.go | 14 +++--- relay/channel/gemini/adaptor.go | 2 +- relay/channel/gemini/relay-gemini.go | 3 +- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/relay-openai.go | 4 +- relay/channel/palm/adaptor.go | 2 +- relay/channel/palm/relay-palm.go | 3 +- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/relay-audio.go | 21 ++++++--- relay/relay-text.go | 12 +++--- service/sensitive.go | 60 ++++++++++++++++++++++++++ service/token_counter.go | 35 ++++++++++----- service/usage_helpr.go | 31 +++++++------ web/src/components/OperationSetting.js | 55 +++++++++++++++++++++++ 22 files changed, 293 insertions(+), 66 deletions(-) create mode 100644 common/str.go create mode 100644 constant/sensitive.go create mode 100644 service/sensitive.go diff --git a/common/str.go b/common/str.go new file mode 100644 index 0000000..d16f7a0 --- /dev/null +++ b/common/str.go @@ -0,0 +1,38 @@ +package common + +func SundaySearch(text string, pattern string) bool { + // 计算偏移表 + offset := make(map[rune]int) + for i, c := range pattern { + offset[c] = len(pattern) - i + } + + // 文本串长度和模式串长度 + n, m := len(text), len(pattern) + + // 主循环,i表示当前对齐的文本串位置 + for i := 0; i <= n-m; { + // 检查子串 + j := 0 + for j < m && text[i+j] == pattern[j] { + j++ + } + // 如果完全匹配,返回匹配位置 + if j == m { + return true + } + + // 如果还有剩余字符,则检查下一位字符在偏移表中的值 + if i+m < n { + next := rune(text[i+m]) + if val, ok := offset[next]; ok { + i += val // 存在于偏移表中,进行跳跃 + } else { + i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度 + } + } else { + break + } + } + return false // 如果没有找到匹配,返回-1 +} diff --git a/constant/sensitive.go b/constant/sensitive.go new file mode 100644 index 0000000..10ecfe6 --- /dev/null +++ b/constant/sensitive.go @@ -0,0 +1,32 @@ +package constant + +import "strings" + +var CheckSensitiveEnabled = true +var CheckSensitiveOnPromptEnabled = true +var CheckSensitiveOnCompletionEnabled = true + +// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词 +var StopOnSensitiveEnabled = true + +// SensitiveWords 敏感词 +// var SensitiveWords []string +var SensitiveWords = []string{ + "test", +} + +func SensitiveWordsToString() string { + return strings.Join(SensitiveWords, "\n") +} + +func SensitiveWordsFromString(s string) { + SensitiveWords = strings.Split(s, "\n") +} + +func ShouldCheckPromptSensitive() bool { + return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled +} + +func ShouldCheckCompletionSensitive() bool { + return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled +} diff --git a/go.mod b/go.mod index 5b877ad..b0c7220 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -27,6 +28,7 @@ require ( ) require ( + github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index 4ff383b..5b17b48 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= +github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= +github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= +github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -15,8 +19,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= @@ -37,7 +39,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -48,8 +49,6 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= -github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= -github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= @@ -105,8 +104,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -154,9 +151,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= @@ -175,8 +171,6 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= @@ -184,8 +178,6 @@ golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9 golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= @@ -200,8 +192,6 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/model/option.go b/model/option.go index a108e93..7422da1 100644 --- a/model/option.go +++ b/model/option.go @@ -90,6 +90,11 @@ func InitOptionMap() { common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) + common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) + common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) + common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) + common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) + common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -185,6 +190,14 @@ func updateOptionMap(key string, value string) (err error) { common.DefaultCollapseSidebar = boolValue case "MjNotifyEnabled": constant.MjNotifyEnabled = boolValue + case "CheckSensitiveEnabled": + constant.CheckSensitiveEnabled = boolValue + case "CheckSensitiveOnPromptEnabled": + constant.CheckSensitiveOnPromptEnabled = boolValue + case "CheckSensitiveOnCompletionEnabled": + constant.CheckSensitiveOnCompletionEnabled = boolValue + case "StopOnSensitiveEnabled": + constant.StopOnSensitiveEnabled = boolValue } } switch key { @@ -273,6 +286,8 @@ func updateOptionMap(key string, value string) (err error) { common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "SensitiveWords": + constant.SensitiveWordsFromString(value) } return err } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0e9aa42..1027faa 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/service" "strings" @@ -194,7 +195,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - var usage dto.Usage + var usage *dto.Usage responseText := "" createdTime := common.GetTimestamp() scanner := bufio.NewScanner(resp.Body) @@ -277,13 +278,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if requestMode == RequestModeCompletion { - usage = *service.ResponseText2Usage(responseText, modelName, promptTokens) + usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) } else { if usage.CompletionTokens == 0 { - usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens) } } - return nil, &usage + return nil, usage } func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -312,7 +313,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens := service.CountTokenText(claudeResponse.Completion, model) + completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) + if err != nil { + return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil + } usage := dto.Usage{} if requestMode == RequestModeCompletion { usage.PromptTokens = promptTokens diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 0943ef0..a275175 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -51,7 +51,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = geminiChatStreamHandler(c, resp) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 94ba0c2..b199178 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -256,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model) + completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index d3fa5ef..55edf7a 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 92621d5..417dbce 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9624606..a3a2634 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" "one-api/service" @@ -153,7 +154,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += service.CountTokenText(string(choice.Message.Content), model) + ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive()) + completionTokens += ctkm } textResponse.Usage = dto.Usage{ PromptTokens: promptTokens, diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 6458858..4f59a44 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index d775651..b3607c0 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -156,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model) + completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 4722bb7..24765ff 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 7571659..470ec14 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = tencentHandler(c, resp) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 9007e33..1b8866b 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) - usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 5d43e79..d68550e 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -62,8 +63,16 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) } } - + var err error + promptTokens := 0 preConsumedTokens := common.PreConsumedQuota + if strings.HasPrefix(audioRequest.Model, "tts-1") { + promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) + if err != nil { + return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) + } + preConsumedTokens = promptTokens + } modelRatio := common.GetModelRatio(audioRequest.Model) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio @@ -161,12 +170,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { go func() { useTimeSeconds := time.Now().Unix() - startTime.Unix() quota := 0 - var promptTokens = 0 if strings.HasPrefix(audioRequest.Model, "tts-1") { - quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model) - promptTokens = quota + quota = promptTokens } else { - quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model) + quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { @@ -208,6 +215,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } + contains, words := service.SensitiveWordContains(audioResponse.Text) + if contains { + return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest) + } } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/relay/relay-text.go b/relay/relay-text.go index c5f79b6..8a38a81 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -10,6 +10,7 @@ import ( "math" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -96,6 +97,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { var preConsumedQuota int var ratio float64 var modelRatio float64 + //err := service.SensitiveWordsCheck(textRequest) promptTokens, err := getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 @@ -172,16 +174,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { var promptTokens int var err error - + checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model) + promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil + promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil + promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil + promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) default: err = errors.New("unknown relay mode") promptTokens = 0 diff --git a/service/sensitive.go b/service/sensitive.go new file mode 100644 index 0000000..6b77849 --- /dev/null +++ b/service/sensitive.go @@ -0,0 +1,60 @@ +package service + +import ( + "bytes" + "fmt" + "github.com/anknown/ahocorasick" + "one-api/constant" + "strings" +) + +// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 +func SensitiveWordContains(text string) (bool, []string) { + // 构建一个AC自动机 + m := initAc() + hits := m.MultiPatternSearch([]rune(text), false) + if len(hits) > 0 { + words := make([]string, 0) + for _, hit := range hits { + words = append(words, string(hit.Word)) + } + return true, words + } + return false, nil +} + +// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 +func SensitiveWordReplace(text string) (bool, string) { + m := initAc() + hits := m.MultiPatternSearch([]rune(text), false) + if len(hits) > 0 { + for _, hit := range hits { + pos := hit.Pos + word := string(hit.Word) + text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):] + } + return true, text + } + return false, text +} + +func initAc() *goahocorasick.Machine { + m := new(goahocorasick.Machine) + dict := readRunes() + if err := m.Build(dict); err != nil { + fmt.Println(err) + return nil + } + return m +} + +func readRunes() [][]rune { + var dict [][]rune + + for _, word := range constant.SensitiveWords { + l := bytes.TrimSpace([]byte(word)) + dict = append(dict, bytes.Runes(l)) + } + + return dict +} diff --git a/service/token_counter.go b/service/token_counter.go index a8e82ed..a04be59 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { return tiles*170 + 85, nil } -func CountTokenMessages(messages []dto.Message, model string) (int, error) { +func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) { if err := json.Unmarshal(message.Content, &stringContent); err != nil { return 0, err } else { + if checkSensitive { + contains, words := SensitiveWordContains(stringContent) + if contains { + err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) + return 0, err + } + } tokenNum += getTokenNum(tokenEncoder, stringContent) if message.Name != nil { tokenNum += tokensPerName @@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) { return tokenNum, nil } -func CountTokenInput(input any, model string) int { +func CountTokenInput(input any, model string, check bool) (int, error) { switch v := input.(type) { case string: - return CountTokenText(v, model) + return CountTokenText(v, model, check) case []string: text := "" for _, s := range v { text += s } - return CountTokenText(text, model) + return CountTokenText(text, model, check) } - return 0 + return 0, errors.New("unsupported input type") } -func CountAudioToken(text string, model string) int { +func CountAudioToken(text string, model string, check bool) (int, error) { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text) + return utf8.RuneCountInString(text), nil } else { - return CountTokenText(text, model) + return CountTokenText(text, model, check) } } -func CountTokenText(text string, model string) int { +// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 +func CountTokenText(text string, model string, check bool) (int, error) { + var err error + if check { + contains, words := SensitiveWordContains(text) + if contains { + err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")) + } + } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text) + return getTokenNum(tokenEncoder, text), err } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index c1fcfb5..53a5c04 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -1,27 +1,26 @@ package service import ( - "errors" "one-api/dto" - "one-api/relay/constant" ) -func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) { - switch relayMode { - case constant.RelayModeChatCompletions: - return CountTokenMessages(textRequest.Messages, textRequest.Model) - case constant.RelayModeCompletions: - return CountTokenInput(textRequest.Prompt, textRequest.Model), nil - case constant.RelayModeModerations: - return CountTokenInput(textRequest.Input, textRequest.Model), nil - } - return 0, errors.New("unknown relay mode") -} +//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) { +// switch relayMode { +// case constant.RelayModeChatCompletions: +// return CountTokenMessages(textRequest.Messages, textRequest.Model) +// case constant.RelayModeCompletions: +// return CountTokenInput(textRequest.Prompt, textRequest.Model), nil +// case constant.RelayModeModerations: +// return CountTokenInput(textRequest.Input, textRequest.Model), nil +// } +// return 0, errors.New("unknown relay mode") +//} -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage { +func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) + ctkm, err := CountTokenText(responseText, modeName, false) + usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage + return usage, err } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 874c552..c006106 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -23,6 +23,11 @@ const OperationSetting = () => { LogConsumeEnabled: '', DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', + CheckSensitiveEnabled: '', + CheckSensitiveOnPromptEnabled: '', + CheckSensitiveOnCompletionEnabled: '', + StopOnSensitiveEnabled: '', + SensitiveWords: '', MjNotifyEnabled: '', DrawingEnabled: '', DataExportEnabled: '', @@ -130,6 +135,11 @@ const OperationSetting = () => { await updateOption('ModelPrice', inputs.ModelPrice); } break; + case 'words': + if (originInputs['SensitiveWords'] !== inputs.SensitiveWords) { + await updateOption('SensitiveWords', inputs.SensitiveWords); + } + break; case 'quota': if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); @@ -273,6 +283,51 @@ const OperationSetting = () => { /> +
+ 屏蔽词过滤设置 +
+ + + + + + + + + + + + { + submitConfig('words').then(); + }}>保存屏蔽词设置 +
日志设置
From 64b9d3b58c03a35d2b616c8a7e207f1b9fe807a1 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 19:00:51 +0800 Subject: [PATCH 02/11] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E7=94=9F=E6=88=90=E5=86=85=E5=AE=B9=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 12 ++++ controller/channel-test.go | 2 +- dto/sensitive.go | 6 ++ dto/text_response.go | 4 +- relay/channel/adapter.go | 2 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/baidu/adaptor.go | 2 +- relay/channel/claude/adaptor.go | 2 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 4 +- relay/channel/openai/adaptor.go | 4 +- relay/channel/openai/relay-openai.go | 101 +++++++++++++++++++-------- relay/channel/palm/adaptor.go | 2 +- relay/channel/perplexity/adaptor.go | 4 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/xunfei/adaptor.go | 6 +- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 4 +- relay/common/relay_utils.go | 2 +- relay/relay-text.go | 25 +++++-- service/sensitive.go | 14 ++-- 21 files changed, 141 insertions(+), 63 deletions(-) create mode 100644 dto/sensitive.go diff --git a/common/str.go b/common/str.go index d16f7a0..ddf8375 100644 --- a/common/str.go +++ b/common/str.go @@ -36,3 +36,15 @@ func SundaySearch(text string, pattern string) bool { } return false // 如果没有找到匹配,返回-1 } + +func RemoveDuplicate(s []string) []string { + result := make([]string, 0, len(s)) + temp := map[string]struct{}{} + for _, item := range s { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + result = append(result, item) + } + } + return result +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 4a2906b..2eee554 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr err := relaycommon.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } - usage, respErr := adaptor.DoResponse(c, resp, meta) + usage, respErr, _ := adaptor.DoResponse(c, resp, meta) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error } diff --git a/dto/sensitive.go b/dto/sensitive.go new file mode 100644 index 0000000..0bfbc6f --- /dev/null +++ b/dto/sensitive.go @@ -0,0 +1,6 @@ +package dto + +type SensitiveResponse struct { + SensitiveWords []string `json:"sensitive_words"` + Content string `json:"content"` +} diff --git a/dto/text_response.go b/dto/text_response.go index 81d0748..7b94ed5 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -1,9 +1,9 @@ package dto type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` + Choices []*OpenAITextResponseChoice `json:"choices"` Usage `json:"usage"` - Error OpenAIError `json:"error"` + Error *OpenAIError `json:"error,omitempty"` } type OpenAITextResponseChoice struct { diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d3886d5..935a0f0 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,7 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfe83db..155d5a0 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = aliStreamHandler(c, resp) } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2571dc..d8c96aa 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = baiduStreamHandler(c, resp) } else { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 45efd01..ee5a4c0 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) } else { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a275175..5f78829 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = geminiChatStreamHandler(c, resp) diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 55edf7a..6ef2f30 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -39,13 +39,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 417dbce..d9c52f8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -71,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index a3a2634..4409555 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,8 +4,11 @@ import ( "bufio" "bytes" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "io" + "log" "net/http" "one-api/common" "one-api/constant" @@ -18,6 +21,7 @@ import ( ) func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { + checkSensitive := constant.ShouldCheckCompletionSensitive() var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -37,11 +41,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d defer close(stopChan) defer close(dataChan) var wg sync.WaitGroup - go func() { wg.Add(1) defer wg.Done() - var streamItems []string + var streamItems []string // store stream items for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format @@ -50,11 +53,20 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d if data[:6] != "data: " && data[:6] != "[DONE]" { continue } + sensitive := false + if checkSensitive { + // check sensitive + sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled) + } dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { streamItems = append(streamItems, data) } + if sensitive && constant.StopOnSensitiveEnabled { + dataChan <- "data: [DONE]" + break + } } streamResp := "[" + strings.Join(streamItems, ",") + "]" switch relayMode { @@ -112,50 +124,48 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d return nil, responseTextBuilder.String() } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { var textResponse dto.TextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil } - if textResponse.Error.Type != "" { + log.Printf("textResponse: %+v", textResponse) + if textResponse.Error != nil { return &dto.OpenAIErrorWithStatusCode{ - Error: textResponse.Error, + Error: *textResponse.Error, StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + }, nil, nil } - if textResponse.Usage.TotalTokens == 0 { + checkSensitive := constant.ShouldCheckCompletionSensitive() + sensitiveWords := make([]string, 0) + triggerSensitive := false + + if textResponse.Usage.TotalTokens == 0 || checkSensitive { completionTokens := 0 for _, choice := range textResponse.Choices { - ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive()) + stringContent := string(choice.Message.Content) + ctkm, _ := service.CountTokenText(stringContent, model, false) completionTokens += ctkm + if checkSensitive { + sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) + if sensitive { + triggerSensitive = true + msg := choice.Message + msg.Content = common.StringToByteSlice(stringContent) + choice.Message = msg + sensitiveWords = append(sensitiveWords, words...) + } + } } textResponse.Usage = dto.Usage{ PromptTokens: promptTokens, @@ -163,5 +173,36 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model TotalTokens: promptTokens + completionTokens, } } - return nil, &textResponse.Usage + + if constant.StopOnSensitiveEnabled { + + } else { + responseBody, err = json.Marshal(textResponse) + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + } + } + + if checkSensitive && triggerSensitive { + sensitiveWords = common.RemoveDuplicate(sensitiveWords) + return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ + SensitiveWords: sensitiveWords, + } + } + return nil, &textResponse.Usage, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4f59a44..2129ee3 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 24765ff..8d056ec 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 470ec14..fa3da0c 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 79a4b12..9ab2818 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { splits := strings.Split(info.ApiKey, "|") if len(splits) != 3 { - return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil } if a.request == nil { - return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) + return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil } if info.IsStream { err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 6f2d186..69c45a8 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = zhipuStreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 1b8866b..dded3c5 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index f89b8b9..71b8ba7 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open if err != nil { return } - OpenAIErrorWithStatusCode.Error = textResponse.Error + OpenAIErrorWithStatusCode.Error = *textResponse.Error return } diff --git a/relay/relay-text.go b/relay/relay-text.go index 8a38a81..b9c1e7a 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.RelayErrorHandler(resp) } - usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return openaiErr + if sensitiveResp == nil { // 如果没有敏感词检查结果 + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + return openaiErr + } else { + // 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额 + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp) + if constant.StopOnSensitiveEnabled { // 是否直接返回错误 + return openaiErr + } + return nil + } } - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil) return nil } @@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu } } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, + usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, + modelPrice float64, sensitiveResp *dto.SensitiveResponse) { + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe logContent += fmt.Sprintf("(可能是上游超时)") common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) } else { + if sensitiveResp != nil { + logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { diff --git a/service/sensitive.go b/service/sensitive.go index 6b77849..57c667f 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -24,18 +24,21 @@ func SensitiveWordContains(text string) (bool, []string) { } // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 -func SensitiveWordReplace(text string) (bool, string) { +func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) { + text = strings.ToLower(text) m := initAc() - hits := m.MultiPatternSearch([]rune(text), false) + hits := m.MultiPatternSearch([]rune(text), returnImmediately) if len(hits) > 0 { + words := make([]string, 0) for _, hit := range hits { pos := hit.Pos word := string(hit.Word) - text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):] + text = text[:pos] + " *###* " + text[pos+len(word):] + words = append(words, word) } - return true, text + return true, words, text } - return false, text + return false, nil, text } func initAc() *goahocorasick.Machine { @@ -52,6 +55,7 @@ func readRunes() [][]rune { var dict [][]rune for _, word := range constant.SensitiveWords { + word = strings.ToLower(word) l := bytes.TrimSpace([]byte(word)) dict = append(dict, bytes.Runes(l)) } From 2db4282666dcd46fdb2ed7543617a65595fa2c3e Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 20:15:32 +0800 Subject: [PATCH 03/11] =?UTF-8?q?feat:=20=E4=BF=9D=E7=95=99=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constant/sensitive.go | 3 +++ model/option.go | 3 +++ web/src/components/OperationSetting.js | 14 ++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/constant/sensitive.go b/constant/sensitive.go index 10ecfe6..fabca67 100644 --- a/constant/sensitive.go +++ b/constant/sensitive.go @@ -9,6 +9,9 @@ var CheckSensitiveOnCompletionEnabled = true // StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词 var StopOnSensitiveEnabled = true +// StreamCacheQueueLength 流模式缓存队列长度,0表示无缓存 +var StreamCacheQueueLength = 0 + // SensitiveWords 敏感词 // var SensitiveWords []string var SensitiveWords = []string{ diff --git a/model/option.go b/model/option.go index 7422da1..46e41da 100644 --- a/model/option.go +++ b/model/option.go @@ -95,6 +95,7 @@ func InitOptionMap() { common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() + common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength) common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -288,6 +289,8 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) case "SensitiveWords": constant.SensitiveWordsFromString(value) + case "StreamCacheQueueLength": + constant.StreamCacheQueueLength, _ = strconv.Atoi(value) } return err } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index c006106..728eab0 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -10,6 +10,7 @@ const OperationSetting = () => { QuotaForInvitee: 0, QuotaRemindThreshold: 0, PreConsumedQuota: 0, + StreamCacheQueueLength: 0, ModelRatio: '', ModelPrice: '', GroupRatio: '', @@ -307,6 +308,8 @@ const OperationSetting = () => { name="CheckSensitiveOnCompletionEnabled" onChange={handleInputChange} /> + + { onChange={handleInputChange} /> + {/**/} + {/* */} + {/**/} Date: Wed, 20 Mar 2024 20:26:34 +0800 Subject: [PATCH 04/11] fix: fix error --- relay/channel/openai/relay-openai.go | 2 +- service/sensitive.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4409555..9ec4260 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -56,7 +56,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d sensitive := false if checkSensitive { // check sensitive - sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled) + sensitive, _, data = service.SensitiveWordReplace(data, false) } dataChan <- data data = data[6:] diff --git a/service/sensitive.go b/service/sensitive.go index 57c667f..28618a0 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -25,15 +25,15 @@ func SensitiveWordContains(text string) (bool, []string) { // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) { - text = strings.ToLower(text) + checkText := strings.ToLower(text) m := initAc() - hits := m.MultiPatternSearch([]rune(text), returnImmediately) + hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { words := make([]string, 0) for _, hit := range hits { pos := hit.Pos word := string(hit.Word) - text = text[:pos] + " *###* " + text[pos+len(word):] + text = text[:pos] + "*###*" + text[pos+len(word):] words = append(words, word) } return true, words, text From eb6257a8d892c9a6ee711890d89e821afa725b59 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 20:28:00 +0800 Subject: [PATCH 05/11] fix: fix SensitiveWordContains not working --- service/sensitive.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/service/sensitive.go b/service/sensitive.go index 28618a0..b216376 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -10,9 +10,10 @@ import ( // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 func SensitiveWordContains(text string) (bool, []string) { + checkText := strings.ToLower(text) // 构建一个AC自动机 m := initAc() - hits := m.MultiPatternSearch([]rune(text), false) + hits := m.MultiPatternSearch([]rune(checkText), false) if len(hits) > 0 { words := make([]string, 0) for _, hit := range hits { From a232afe9fdeea02452f96c3a85194344cdc29270 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 20:36:55 +0800 Subject: [PATCH 06/11] =?UTF-8?q?feat:=20=E7=BB=9F=E4=B8=80=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/claude/relay-claude.go | 2 +- relay/channel/gemini/relay-gemini.go | 2 +- relay/channel/openai/relay-openai.go | 2 +- relay/channel/palm/relay-palm.go | 2 +- relay/relay-audio.go | 4 ++-- relay/relay-text.go | 18 +++++++++++------- service/token_counter.go | 28 +++++++++++++++++----------- service/usage_helpr.go | 2 +- 8 files changed, 35 insertions(+), 25 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1027faa..a56c1bb 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b199178..31badd8 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9ec4260..7e36861 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model completionTokens := 0 for _, choice := range textResponse.Choices { stringContent := string(choice.Message.Content) - ctkm, _ := service.CountTokenText(stringContent, model, false) + ctkm, _, _ := service.CountTokenText(stringContent, model, false) completionTokens += ctkm if checkSensitive { sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index b3607c0..4028269 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/relay-audio.go b/relay/relay-audio.go index d68550e..1c0f868 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if strings.HasPrefix(audioRequest.Model, "tts-1") { - promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) + promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive()) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } @@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if strings.HasPrefix(audioRequest.Model, "tts-1") { quota = promptTokens } else { - quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) + quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { diff --git a/relay/relay-text.go b/relay/relay-text.go index b9c1e7a..a8ba0e8 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { var ratio float64 var modelRatio float64 //err := service.SensitiveWordsCheck(textRequest) - promptTokens, err := getPromptTokens(textRequest, relayInfo) + promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { + if sensitiveTrigger { + return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + } return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } @@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { +func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) { var promptTokens int var err error + var sensitiveTrigger bool checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) default: err = errors.New("unknown relay mode") promptTokens = 0 } info.PromptTokens = promptTokens - return promptTokens, err + return promptTokens, err, sensitiveTrigger } // 预扣费并返回用户剩余配额 diff --git a/service/token_counter.go b/service/token_counter.go index a04be59..4769dab 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { return tiles*170 + 85, nil } -func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) { +func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo if err := json.Unmarshal(message.Content, &arrayContent); err != nil { var stringContent string if err := json.Unmarshal(message.Content, &stringContent); err != nil { - return 0, err + return 0, err, false } else { if checkSensitive { contains, words := SensitiveWordContains(stringContent) if contains { err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) - return 0, err + return 0, err, true } } tokenNum += getTokenNum(tokenEncoder, stringContent) @@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo imageTokenNum, err = getImageToken(&imageUrl) } if err != nil { - return 0, err + return 0, err, false } } tokenNum += imageTokenNum @@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil + return tokenNum, nil, false } -func CountTokenInput(input any, model string, check bool) (int, error) { +func CountTokenInput(input any, model string, check bool) (int, error, bool) { switch v := input.(type) { case string: return CountTokenText(v, model, check) @@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) { } return CountTokenText(text, model, check) } - return 0, errors.New("unsupported input type") + return 0, errors.New("unsupported input type"), false } -func CountAudioToken(text string, model string, check bool) (int, error) { +func CountAudioToken(text string, model string, check bool) (int, error, bool) { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text), nil + contains, words := SensitiveWordContains(text) + if contains { + return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true + } + return utf8.RuneCountInString(text), nil, false } else { return CountTokenText(text, model, check) } } // CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTokenText(text string, model string, check bool) (int, error) { +func CountTokenText(text string, model string, check bool) (int, error, bool) { var err error + var trigger bool if check { contains, words := SensitiveWordContains(text) if contains { err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")) + trigger = true } } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err + return getTokenNum(tokenEncoder, text), err, trigger } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 53a5c04..460ac56 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -19,7 +19,7 @@ import ( func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTokenText(responseText, modeName, false) + ctkm, err, _ := CountTokenText(responseText, modeName, false) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err From b736de71896fe4d01e7e0ff019b6a79379c7ae1c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 21:28:45 +0800 Subject: [PATCH 07/11] fix: claude panic --- relay/channel/claude/relay-claude.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index a56c1bb..6c89f34 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -150,6 +150,9 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* claudeUsage = &claudeResponse.Usage } } + if claudeUsage == nil { + claudeUsage = &ClaudeUsage{} + } response.Choices = append(response.Choices, choice) return &response, claudeUsage } From dd71946047e9bd31e6947609ea98292c683af5c1 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 21:32:33 +0800 Subject: [PATCH 08/11] fix: claude panic --- relay/channel/claude/relay-claude.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 6c89f34..f1a3c39 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -199,6 +199,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) var usage *dto.Usage + usage = &dto.Usage{} responseText := "" createdTime := common.GetTimestamp() scanner := bufio.NewScanner(resp.Body) From d5e93e788d11a0a02221243113efeba7a8fd9def Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 21:49:54 +0800 Subject: [PATCH 09/11] fix: midjourneys table --- model/main.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/model/main.go b/model/main.go index 62a02a2..7b1cd3d 100644 --- a/model/main.go +++ b/model/main.go @@ -94,7 +94,10 @@ func InitDB() (err error) { return nil } if common.UsingMySQL { - _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded + _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded + _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded } common.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) From c7658b70d1202a10cb673a205a21a35287408e7d Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 22:33:22 +0800 Subject: [PATCH 10/11] =?UTF-8?q?fix:=20=E6=95=8F=E6=84=9F=E8=AF=8D?= =?UTF-8?q?=E5=BA=93=E4=B8=BA=E7=A9=BA=E6=97=B6=EF=BC=8C=E4=B8=8D=E6=A3=80?= =?UTF-8?q?=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constant/sensitive.go | 8 +++++++- service/sensitive.go | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/constant/sensitive.go b/constant/sensitive.go index fabca67..4a9a7e6 100644 --- a/constant/sensitive.go +++ b/constant/sensitive.go @@ -23,7 +23,13 @@ func SensitiveWordsToString() string { } func SensitiveWordsFromString(s string) { - SensitiveWords = strings.Split(s, "\n") + sw := strings.Split(s, "\n") + for _, w := range sw { + w = strings.TrimSpace(w) + if w != "" { + SensitiveWords = append(SensitiveWords, w) + } + } } func ShouldCheckPromptSensitive() bool { diff --git a/service/sensitive.go b/service/sensitive.go index b216376..dbb0887 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -10,6 +10,9 @@ import ( // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 func SensitiveWordContains(text string) (bool, []string) { + if len(constant.SensitiveWords) == 0 { + return false, nil + } checkText := strings.ToLower(text) // 构建一个AC自动机 m := initAc() @@ -26,6 +29,9 @@ func SensitiveWordContains(text string) (bool, []string) { // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) { + if len(constant.SensitiveWords) == 0 { + return false, nil, text + } checkText := strings.ToLower(text) m := initAc() hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) From d7e25e16046dea684982907be473c09a9c9020a4 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 23:58:42 +0800 Subject: [PATCH 11/11] fix: fix SensitiveWords load error --- constant/sensitive.go | 1 + 1 file changed, 1 insertion(+) diff --git a/constant/sensitive.go b/constant/sensitive.go index 4a9a7e6..8d8e15f 100644 --- a/constant/sensitive.go +++ b/constant/sensitive.go @@ -23,6 +23,7 @@ func SensitiveWordsToString() string { } func SensitiveWordsFromString(s string) { + SensitiveWords = []string{} sw := strings.Split(s, "\n") for _, w := range sw { w = strings.TrimSpace(w)