From 7a663d26ec2b30cedc7c62378ff2ad8fdb135561 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 17:07:42 +0800
Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?=
=?UTF-8?q?=E5=AE=B9=E6=95=8F=E6=84=9F=E8=AF=8D=E8=BF=87=E6=BB=A4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/str.go | 38 ++++++++++++++++
constant/sensitive.go | 32 ++++++++++++++
go.mod | 2 +
go.sum | 20 +++------
model/option.go | 15 +++++++
relay/channel/claude/relay-claude.go | 14 +++---
relay/channel/gemini/adaptor.go | 2 +-
relay/channel/gemini/relay-gemini.go | 3 +-
relay/channel/ollama/adaptor.go | 2 +-
relay/channel/openai/adaptor.go | 2 +-
relay/channel/openai/relay-openai.go | 4 +-
relay/channel/palm/adaptor.go | 2 +-
relay/channel/palm/relay-palm.go | 3 +-
relay/channel/perplexity/adaptor.go | 2 +-
relay/channel/tencent/adaptor.go | 2 +-
relay/channel/zhipu_4v/adaptor.go | 2 +-
relay/relay-audio.go | 21 ++++++---
relay/relay-text.go | 12 +++---
service/sensitive.go | 60 ++++++++++++++++++++++++++
service/token_counter.go | 35 ++++++++++-----
service/usage_helpr.go | 31 +++++++------
web/src/components/OperationSetting.js | 55 +++++++++++++++++++++++
22 files changed, 293 insertions(+), 66 deletions(-)
create mode 100644 common/str.go
create mode 100644 constant/sensitive.go
create mode 100644 service/sensitive.go
diff --git a/common/str.go b/common/str.go
new file mode 100644
index 0000000..d16f7a0
--- /dev/null
+++ b/common/str.go
@@ -0,0 +1,38 @@
+package common
+
+func SundaySearch(text string, pattern string) bool {
+ // 计算偏移表
+ offset := make(map[rune]int)
+ for i, c := range pattern {
+ offset[c] = len(pattern) - i
+ }
+
+ // 文本串长度和模式串长度
+ n, m := len(text), len(pattern)
+
+ // 主循环,i表示当前对齐的文本串位置
+ for i := 0; i <= n-m; {
+ // 检查子串
+ j := 0
+ for j < m && text[i+j] == pattern[j] {
+ j++
+ }
+ // 如果完全匹配,返回匹配位置
+ if j == m {
+ return true
+ }
+
+ // 如果还有剩余字符,则检查下一位字符在偏移表中的值
+ if i+m < n {
+ next := rune(text[i+m])
+ if val, ok := offset[next]; ok {
+ i += val // 存在于偏移表中,进行跳跃
+ } else {
+ i += len(pattern) + 1 // 不存在于偏移表中,跳过模式串长度加一(Sunday 算法的最大位移)
+ }
+ } else {
+ break
+ }
+ }
+ return false // 没有找到匹配时返回 false(函数返回 bool,而非下标)
+}
diff --git a/constant/sensitive.go b/constant/sensitive.go
new file mode 100644
index 0000000..10ecfe6
--- /dev/null
+++ b/constant/sensitive.go
@@ -0,0 +1,32 @@
+package constant
+
+import "strings"
+
+var CheckSensitiveEnabled = true
+var CheckSensitiveOnPromptEnabled = true
+var CheckSensitiveOnCompletionEnabled = true
+
+// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
+var StopOnSensitiveEnabled = true
+
+// SensitiveWords 敏感词
+// var SensitiveWords []string
+var SensitiveWords = []string{
+ "test",
+}
+
+func SensitiveWordsToString() string {
+ return strings.Join(SensitiveWords, "\n")
+}
+
+func SensitiveWordsFromString(s string) {
+ SensitiveWords = strings.Split(s, "\n")
+}
+
+func ShouldCheckPromptSensitive() bool {
+ return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
+}
+
+func ShouldCheckCompletionSensitive() bool {
+ return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
+}
diff --git a/go.mod b/go.mod
index 5b877ad..b0c7220 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,7 @@ module one-api
go 1.18
require (
+ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -27,6 +28,7 @@ require (
)
require (
+ github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
diff --git a/go.sum b/go.sum
index 4ff383b..5b17b48 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,7 @@
+github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
+github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
+github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
+github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -15,8 +19,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
-github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
-github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
@@ -37,7 +39,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
-github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -48,8 +49,6 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
-github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
-github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
@@ -105,8 +104,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
-github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
-github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
@@ -154,9 +151,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -175,8 +171,6 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
@@ -184,8 +178,6 @@ golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
-golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
@@ -200,8 +192,6 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
-golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
diff --git a/model/option.go b/model/option.go
index a108e93..7422da1 100644
--- a/model/option.go
+++ b/model/option.go
@@ -90,6 +90,11 @@ func InitOptionMap() {
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
+ common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
+ common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
+ common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
+ common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
+ common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -185,6 +190,14 @@ func updateOptionMap(key string, value string) (err error) {
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
+ case "CheckSensitiveEnabled":
+ constant.CheckSensitiveEnabled = boolValue
+ case "CheckSensitiveOnPromptEnabled":
+ constant.CheckSensitiveOnPromptEnabled = boolValue
+ case "CheckSensitiveOnCompletionEnabled":
+ constant.CheckSensitiveOnCompletionEnabled = boolValue
+ case "StopOnSensitiveEnabled":
+ constant.StopOnSensitiveEnabled = boolValue
}
}
switch key {
@@ -273,6 +286,8 @@ func updateOptionMap(key string, value string) (err error) {
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ case "SensitiveWords":
+ constant.SensitiveWordsFromString(value)
}
return err
}
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 0e9aa42..1027faa 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/service"
"strings"
@@ -194,7 +195,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- var usage dto.Usage
+ var usage *dto.Usage
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
@@ -277,13 +278,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
- usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
} else {
if usage.CompletionTokens == 0 {
- usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
}
}
- return nil, &usage
+ return nil, usage
}
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -312,7 +313,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
- completionTokens := service.CountTokenText(claudeResponse.Completion, model)
+ completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
+ }
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 0943ef0..a275175 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -51,7 +51,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 94ba0c2..b199178 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -256,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
+ completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index d3fa5ef..55edf7a 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 92621d5..417dbce 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 9624606..a3a2634 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/service"
@@ -153,7 +154,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
- completionTokens += service.CountTokenText(string(choice.Message.Content), model)
+ ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive())
+ completionTokens += ctkm
}
textResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 6458858..4f59a44 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index d775651..b3607c0 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -156,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+ completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index 4722bb7..24765ff 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 7571659..470ec14 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -57,7 +57,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index 9007e33..1b8866b 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index 5d43e79..d68550e 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -62,8 +63,16 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
-
+ var err error
+ promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
+ if strings.HasPrefix(audioRequest.Model, "tts-1") {
+ promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
+ }
+ preConsumedTokens = promptTokens
+ }
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
@@ -161,12 +170,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
quota := 0
- var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") {
- quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
- promptTokens = quota
+ quota = promptTokens
} else {
- quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
+ quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
@@ -208,6 +215,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
+ contains, words := service.SensitiveWordContains(audioResponse.Text)
+ if contains {
+ return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
+ }
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
diff --git a/relay/relay-text.go b/relay/relay-text.go
index c5f79b6..8a38a81 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -10,6 +10,7 @@ import (
"math"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -96,6 +97,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
var preConsumedQuota int
var ratio float64
var modelRatio float64
+ //err := service.SensitiveWordsCheck(textRequest)
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
@@ -172,16 +174,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
-
+ checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
- promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
+ promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
- promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
+ promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
+ promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
case relayconstant.RelayModeEmbeddings:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
+ promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
diff --git a/service/sensitive.go b/service/sensitive.go
new file mode 100644
index 0000000..6b77849
--- /dev/null
+++ b/service/sensitive.go
@@ -0,0 +1,60 @@
+package service
+
+import (
+ "bytes"
+ "fmt"
+ "github.com/anknown/ahocorasick"
+ "one-api/constant"
+ "strings"
+)
+
+// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
+func SensitiveWordContains(text string) (bool, []string) {
+ // 构建一个AC自动机
+ m := initAc()
+ hits := m.MultiPatternSearch([]rune(text), false)
+ if len(hits) > 0 {
+ words := make([]string, 0)
+ for _, hit := range hits {
+ words = append(words, string(hit.Word))
+ }
+ return true, words
+ }
+ return false, nil
+}
+
+// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
+func SensitiveWordReplace(text string) (bool, string) {
+ m := initAc()
+ hits := m.MultiPatternSearch([]rune(text), false)
+ if len(hits) > 0 {
+ for _, hit := range hits {
+ pos := hit.Pos
+ word := string(hit.Word)
+ text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):]
+ }
+ return true, text
+ }
+ return false, text
+}
+
+func initAc() *goahocorasick.Machine {
+ m := new(goahocorasick.Machine)
+ dict := readRunes()
+ if err := m.Build(dict); err != nil {
+ fmt.Println(err)
+ return nil
+ }
+ return m
+}
+
+func readRunes() [][]rune {
+ var dict [][]rune
+
+ for _, word := range constant.SensitiveWords {
+ l := bytes.TrimSpace([]byte(word))
+ dict = append(dict, bytes.Runes(l))
+ }
+
+ return dict
+}
diff --git a/service/token_counter.go b/service/token_counter.go
index a8e82ed..a04be59 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
-func CountTokenMessages(messages []dto.Message, model string) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err
} else {
+ if checkSensitive {
+ contains, words := SensitiveWordContains(stringContent)
+ if contains {
+ err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
+ return 0, err
+ }
+ }
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
@@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
return tokenNum, nil
}
-func CountTokenInput(input any, model string) int {
+func CountTokenInput(input any, model string, check bool) (int, error) {
switch v := input.(type) {
case string:
- return CountTokenText(v, model)
+ return CountTokenText(v, model, check)
case []string:
text := ""
for _, s := range v {
text += s
}
- return CountTokenText(text, model)
+ return CountTokenText(text, model, check)
}
- return 0
+ return 0, errors.New("unsupported input type")
}
-func CountAudioToken(text string, model string) int {
+func CountAudioToken(text string, model string, check bool) (int, error) {
if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text)
+ return utf8.RuneCountInString(text), nil
} else {
- return CountTokenText(text, model)
+ return CountTokenText(text, model, check)
}
}
-func CountTokenText(text string, model string) int {
+// CountTokenText 统计文本的 token 数量并始终返回;当 check 为 true 且文本包含敏感词时,同时返回非 nil 错误
+func CountTokenText(text string, model string, check bool) (int, error) {
+ var err error
+ if check {
+ contains, words := SensitiveWordContains(text)
+ if contains {
+ err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
+ }
+ }
tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text)
+ return getTokenNum(tokenEncoder, text), err
}
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index c1fcfb5..53a5c04 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -1,27 +1,26 @@
package service
import (
- "errors"
"one-api/dto"
- "one-api/relay/constant"
)
-func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
- switch relayMode {
- case constant.RelayModeChatCompletions:
- return CountTokenMessages(textRequest.Messages, textRequest.Model)
- case constant.RelayModeCompletions:
- return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
- case constant.RelayModeModerations:
- return CountTokenInput(textRequest.Input, textRequest.Model), nil
- }
- return 0, errors.New("unknown relay mode")
-}
+//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
+// switch relayMode {
+// case constant.RelayModeChatCompletions:
+// return CountTokenMessages(textRequest.Messages, textRequest.Model)
+// case constant.RelayModeCompletions:
+// return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
+// case constant.RelayModeModerations:
+// return CountTokenInput(textRequest.Input, textRequest.Model), nil
+// }
+// return 0, errors.New("unknown relay mode")
+//}
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
- usage.CompletionTokens = CountTokenText(responseText, modeName)
+ ctkm, err := CountTokenText(responseText, modeName, false)
+ usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return usage
+ return usage, err
}
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 874c552..c006106 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -23,6 +23,11 @@ const OperationSetting = () => {
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
+ CheckSensitiveEnabled: '',
+ CheckSensitiveOnPromptEnabled: '',
+ CheckSensitiveOnCompletionEnabled: '',
+ StopOnSensitiveEnabled: '',
+ SensitiveWords: '',
MjNotifyEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
@@ -130,6 +135,11 @@ const OperationSetting = () => {
await updateOption('ModelPrice', inputs.ModelPrice);
}
break;
+ case 'words':
+ if (originInputs['SensitiveWords'] !== inputs.SensitiveWords) {
+ await updateOption('SensitiveWords', inputs.SensitiveWords);
+ }
+ break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
@@ -273,6 +283,51 @@ const OperationSetting = () => {
/>
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ submitConfig('words').then();
+ }}>保存屏蔽词设置
+
From 64b9d3b58c03a35d2b616c8a7e207f1b9fe807a1 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 19:00:51 +0800
Subject: [PATCH 02/11] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?=
=?UTF-8?q?=E5=AE=B9=E7=94=9F=E6=88=90=E5=86=85=E5=AE=B9=E6=A3=80=E6=9F=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/str.go | 12 ++++
controller/channel-test.go | 2 +-
dto/sensitive.go | 6 ++
dto/text_response.go | 4 +-
relay/channel/adapter.go | 2 +-
relay/channel/ali/adaptor.go | 2 +-
relay/channel/baidu/adaptor.go | 2 +-
relay/channel/claude/adaptor.go | 2 +-
relay/channel/gemini/adaptor.go | 2 +-
relay/channel/ollama/adaptor.go | 4 +-
relay/channel/openai/adaptor.go | 4 +-
relay/channel/openai/relay-openai.go | 101 +++++++++++++++++++--------
relay/channel/palm/adaptor.go | 2 +-
relay/channel/perplexity/adaptor.go | 4 +-
relay/channel/tencent/adaptor.go | 2 +-
relay/channel/xunfei/adaptor.go | 6 +-
relay/channel/zhipu/adaptor.go | 2 +-
relay/channel/zhipu_4v/adaptor.go | 4 +-
relay/common/relay_utils.go | 2 +-
relay/relay-text.go | 25 +++++--
service/sensitive.go | 14 ++--
21 files changed, 141 insertions(+), 63 deletions(-)
create mode 100644 dto/sensitive.go
diff --git a/common/str.go b/common/str.go
index d16f7a0..ddf8375 100644
--- a/common/str.go
+++ b/common/str.go
@@ -36,3 +36,15 @@ func SundaySearch(text string, pattern string) bool {
}
return false // 如果没有找到匹配,返回-1
}
+
+func RemoveDuplicate(s []string) []string {
+ result := make([]string, 0, len(s))
+ temp := map[string]struct{}{}
+ for _, item := range s {
+ if _, ok := temp[item]; !ok {
+ temp[item] = struct{}{}
+ result = append(result, item)
+ }
+ }
+ return result
+}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 4a2906b..2eee554 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
- usage, respErr := adaptor.DoResponse(c, resp, meta)
+ usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
diff --git a/dto/sensitive.go b/dto/sensitive.go
new file mode 100644
index 0000000..0bfbc6f
--- /dev/null
+++ b/dto/sensitive.go
@@ -0,0 +1,6 @@
+package dto
+
+type SensitiveResponse struct {
+ SensitiveWords []string `json:"sensitive_words"`
+ Content string `json:"content"`
+}
diff --git a/dto/text_response.go b/dto/text_response.go
index 81d0748..7b94ed5 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -1,9 +1,9 @@
package dto
type TextResponse struct {
- Choices []OpenAITextResponseChoice `json:"choices"`
+ Choices []*OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
- Error OpenAIError `json:"error"`
+ Error *OpenAIError `json:"error,omitempty"`
}
type OpenAITextResponseChoice struct {
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index d3886d5..935a0f0 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -15,7 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
- DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
+ DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
GetModelList() []string
GetChannelName() string
}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index bfe83db..155d5a0 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index d2571dc..d8c96aa 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
} else {
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 45efd01..ee5a4c0 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index a275175..5f78829 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 55edf7a..6ef2f30 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -39,13 +39,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 417dbce..d9c52f8 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -71,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index a3a2634..4409555 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -4,8 +4,11 @@ import (
"bufio"
"bytes"
"encoding/json"
+ "errors"
+ "fmt"
"github.com/gin-gonic/gin"
"io"
+ "log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -18,6 +21,7 @@ import (
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
+ checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -37,11 +41,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
-
go func() {
wg.Add(1)
defer wg.Done()
- var streamItems []string
+ var streamItems []string // store stream items
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
@@ -50,11 +53,20 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
+ sensitive := false
+ if checkSensitive {
+ // check sensitive
+ sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled)
+ }
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
}
+ if sensitive && constant.StopOnSensitiveEnabled {
+ dataChan <- "data: [DONE]"
+ break
+ }
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
@@ -112,50 +124,48 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String()
}
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponse dto.TextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = resp.Body.Close()
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
}
- if textResponse.Error.Type != "" {
+ log.Printf("textResponse: %+v", textResponse)
+ if textResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{
- Error: textResponse.Error,
+ Error: *textResponse.Error,
StatusCode: resp.StatusCode,
- }, nil
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }, nil, nil
}
- if textResponse.Usage.TotalTokens == 0 {
+ checkSensitive := constant.ShouldCheckCompletionSensitive()
+ sensitiveWords := make([]string, 0)
+ triggerSensitive := false
+
+ if textResponse.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0
for _, choice := range textResponse.Choices {
- ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive())
+ stringContent := string(choice.Message.Content)
+ ctkm, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
+ if checkSensitive {
+ sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
+ if sensitive {
+ triggerSensitive = true
+ msg := choice.Message
+ msg.Content = common.StringToByteSlice(stringContent)
+ choice.Message = msg
+ sensitiveWords = append(sensitiveWords, words...)
+ }
+ }
}
textResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
@@ -163,5 +173,36 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
TotalTokens: promptTokens + completionTokens,
}
}
- return nil, &textResponse.Usage
+
+ if constant.StopOnSensitiveEnabled {
+
+ } else {
+ responseBody, err = json.Marshal(textResponse)
+ // Reset response body
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ // We shouldn't set the header before we parse the response body, because the parse part may fail.
+ // And then we will have to send an error response, but in this case, the header has already been set.
+ // So the httpClient will be confused by the response.
+ // For example, Postman will report error, and we cannot check the response at all.
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+ }
+ }
+
+ if checkSensitive && triggerSensitive {
+ sensitiveWords = common.RemoveDuplicate(sensitiveWords)
+ return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
+ SensitiveWords: sensitiveWords,
+ }
+ }
+ return nil, &textResponse.Usage, nil
}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 4f59a44..2129ee3 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index 24765ff..8d056ec 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 470ec14..fa3da0c 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 79a4b12..9ab2818 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return dummyResp, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
- return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+ return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
}
if a.request == nil {
- return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+ return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
}
if info.IsStream {
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 6f2d186..69c45a8 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
} else {
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index 1b8866b..dded3c5 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index f89b8b9..71b8ba7 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
if err != nil {
return
}
- OpenAIErrorWithStatusCode.Error = textResponse.Error
+ OpenAIErrorWithStatusCode.Error = *textResponse.Error
return
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 8a38a81..b9c1e7a 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.RelayErrorHandler(resp)
}
- usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+ usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
- return openaiErr
+ if sensitiveResp == nil { // 如果没有敏感词检查结果
+ returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+ return openaiErr
+ } else {
+ // 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
+ postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
+ if constant.StopOnSensitiveEnabled { // 是否直接返回错误
+ return openaiErr
+ }
+ return nil
+ }
}
- postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
+ postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
return nil
}
@@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
}
}
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
+ usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
+ modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
+
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
@@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
} else {
+ if sensitiveResp != nil {
+ logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+ }
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
diff --git a/service/sensitive.go b/service/sensitive.go
index 6b77849..57c667f 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -24,18 +24,21 @@ func SensitiveWordContains(text string) (bool, []string) {
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
-func SensitiveWordReplace(text string) (bool, string) {
+func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
+ text = strings.ToLower(text)
m := initAc()
- hits := m.MultiPatternSearch([]rune(text), false)
+ hits := m.MultiPatternSearch([]rune(text), returnImmediately)
if len(hits) > 0 {
+ words := make([]string, 0)
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
- text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):]
+ text = text[:pos] + " *###* " + text[pos+len(word):]
+ words = append(words, word)
}
- return true, text
+ return true, words, text
}
- return false, text
+ return false, nil, text
}
func initAc() *goahocorasick.Machine {
@@ -52,6 +55,7 @@ func readRunes() [][]rune {
var dict [][]rune
for _, word := range constant.SensitiveWords {
+ word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
}
From 2db4282666dcd46fdb2ed7543617a65595fa2c3e Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 20:15:32 +0800
Subject: [PATCH 03/11] =?UTF-8?q?feat:=20=E4=BF=9D=E7=95=99=E5=8A=9F?=
=?UTF-8?q?=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
constant/sensitive.go | 3 +++
model/option.go | 3 +++
web/src/components/OperationSetting.js | 14 ++++++++++++++
3 files changed, 20 insertions(+)
diff --git a/constant/sensitive.go b/constant/sensitive.go
index 10ecfe6..fabca67 100644
--- a/constant/sensitive.go
+++ b/constant/sensitive.go
@@ -9,6 +9,9 @@ var CheckSensitiveOnCompletionEnabled = true
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
var StopOnSensitiveEnabled = true
+// StreamCacheQueueLength 流模式缓存队列长度,0表示无缓存
+var StreamCacheQueueLength = 0
+
// SensitiveWords 敏感词
// var SensitiveWords []string
var SensitiveWords = []string{
diff --git a/model/option.go b/model/option.go
index 7422da1..46e41da 100644
--- a/model/option.go
+++ b/model/option.go
@@ -95,6 +95,7 @@ func InitOptionMap() {
common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
+ common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -288,6 +289,8 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
constant.SensitiveWordsFromString(value)
+ case "StreamCacheQueueLength":
+ constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index c006106..728eab0 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -10,6 +10,7 @@ const OperationSetting = () => {
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
+ StreamCacheQueueLength: 0,
ModelRatio: '',
ModelPrice: '',
GroupRatio: '',
@@ -307,6 +308,8 @@ const OperationSetting = () => {
name="CheckSensitiveOnCompletionEnabled"
onChange={handleInputChange}
/>
+
+
{
onChange={handleInputChange}
/>
+ {/**/}
+ {/* */}
+ {/**/}
Date: Wed, 20 Mar 2024 20:26:34 +0800
Subject: [PATCH 04/11] fix: fix error
---
relay/channel/openai/relay-openai.go | 2 +-
service/sensitive.go | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 4409555..9ec4260 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -56,7 +56,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
sensitive := false
if checkSensitive {
// check sensitive
- sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled)
+ sensitive, _, data = service.SensitiveWordReplace(data, false)
}
dataChan <- data
data = data[6:]
diff --git a/service/sensitive.go b/service/sensitive.go
index 57c667f..28618a0 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -25,15 +25,15 @@ func SensitiveWordContains(text string) (bool, []string) {
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
- text = strings.ToLower(text)
+ checkText := strings.ToLower(text)
m := initAc()
- hits := m.MultiPatternSearch([]rune(text), returnImmediately)
+ hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
- text = text[:pos] + " *###* " + text[pos+len(word):]
+ text = text[:pos] + "*###*" + text[pos+len(word):]
words = append(words, word)
}
return true, words, text
From eb6257a8d892c9a6ee711890d89e821afa725b59 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 20:28:00 +0800
Subject: [PATCH 05/11] fix: fix SensitiveWordContains not working
---
service/sensitive.go | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/service/sensitive.go b/service/sensitive.go
index 28618a0..b216376 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -10,9 +10,10 @@ import (
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) {
+ checkText := strings.ToLower(text)
// 构建一个AC自动机
m := initAc()
- hits := m.MultiPatternSearch([]rune(text), false)
+ hits := m.MultiPatternSearch([]rune(checkText), false)
if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits {
From a232afe9fdeea02452f96c3a85194344cdc29270 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 20:36:55 +0800
Subject: [PATCH 06/11] =?UTF-8?q?feat:=20=E7=BB=9F=E4=B8=80=E9=94=99?=
=?UTF-8?q?=E8=AF=AF=E6=8F=90=E7=A4=BA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
relay/channel/claude/relay-claude.go | 2 +-
relay/channel/gemini/relay-gemini.go | 2 +-
relay/channel/openai/relay-openai.go | 2 +-
relay/channel/palm/relay-palm.go | 2 +-
relay/relay-audio.go | 4 ++--
relay/relay-text.go | 18 +++++++++++-------
service/token_counter.go | 28 +++++++++++++++++-----------
service/usage_helpr.go | 2 +-
8 files changed, 35 insertions(+), 25 deletions(-)
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 1027faa..a56c1bb 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
- completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index b199178..31badd8 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 9ec4260..7e36861 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
completionTokens := 0
for _, choice := range textResponse.Choices {
stringContent := string(choice.Message.Content)
- ctkm, _ := service.CountTokenText(stringContent, model, false)
+ ctkm, _, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index b3607c0..4028269 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index d68550e..1c0f868 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if strings.HasPrefix(audioRequest.Model, "tts-1") {
- promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
+ promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = promptTokens
} else {
- quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+ quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
diff --git a/relay/relay-text.go b/relay/relay-text.go
index b9c1e7a..a8ba0e8 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
- promptTokens, err := getPromptTokens(textRequest, relayInfo)
+ promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
+ if sensitiveTrigger {
+ return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+ }
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
@@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return nil
}
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
var promptTokens int
var err error
+ var sensitiveTrigger bool
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
- promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+ promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
- promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
+ promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+ promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
case relayconstant.RelayModeEmbeddings:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+ promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
}
info.PromptTokens = promptTokens
- return promptTokens, err
+ return promptTokens, err, sensitiveTrigger
}
// 预扣费并返回用户剩余配额
diff --git a/service/token_counter.go b/service/token_counter.go
index a04be59..4769dab 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
- return 0, err
+ return 0, err, false
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
- return 0, err
+ return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
@@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
- return 0, err
+ return 0, err, false
}
}
tokenNum += imageTokenNum
@@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
- return tokenNum, nil
+ return tokenNum, nil, false
}
-func CountTokenInput(input any, model string, check bool) (int, error) {
+func CountTokenInput(input any, model string, check bool) (int, error, bool) {
switch v := input.(type) {
case string:
return CountTokenText(v, model, check)
@@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
}
return CountTokenText(text, model, check)
}
- return 0, errors.New("unsupported input type")
+ return 0, errors.New("unsupported input type"), false
}
-func CountAudioToken(text string, model string, check bool) (int, error) {
+func CountAudioToken(text string, model string, check bool) (int, error, bool) {
if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text), nil
+ contains, words := SensitiveWordContains(text)
+ if contains {
+ return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
+ }
+ return utf8.RuneCountInString(text), nil, false
} else {
return CountTokenText(text, model, check)
}
}
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string, check bool) (int, error) {
+func CountTokenText(text string, model string, check bool) (int, error, bool) {
var err error
+ var trigger bool
if check {
contains, words := SensitiveWordContains(text)
if contains {
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
+ trigger = true
}
}
tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text), err
+ return getTokenNum(tokenEncoder, text), err, trigger
}
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index 53a5c04..460ac56 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
- ctkm, err := CountTokenText(responseText, modeName, false)
+ ctkm, err, _ := CountTokenText(responseText, modeName, false)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err
From b736de71896fe4d01e7e0ff019b6a79379c7ae1c Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 21:28:45 +0800
Subject: [PATCH 07/11] fix: claude panic
---
relay/channel/claude/relay-claude.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index a56c1bb..6c89f34 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -150,6 +150,9 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
claudeUsage = &claudeResponse.Usage
}
}
+ if claudeUsage == nil {
+ claudeUsage = &ClaudeUsage{}
+ }
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
From dd71946047e9bd31e6947609ea98292c683af5c1 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 21:32:33 +0800
Subject: [PATCH 08/11] fix: claude panic
---
relay/channel/claude/relay-claude.go | 1 +
1 file changed, 1 insertion(+)
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 6c89f34..f1a3c39 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -199,6 +199,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
+ usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
From d5e93e788d11a0a02221243113efeba7a8fd9def Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 21:49:54 +0800
Subject: [PATCH 09/11] fix: midjourneys table
---
model/main.go | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/model/main.go b/model/main.go
index 62a02a2..7b1cd3d 100644
--- a/model/main.go
+++ b/model/main.go
@@ -94,7 +94,10 @@ func InitDB() (err error) {
return nil
}
if common.UsingMySQL {
- _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
+ _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
+ _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
+ _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
+ _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
From c7658b70d1202a10cb673a205a21a35287408e7d Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 22:33:22 +0800
Subject: [PATCH 10/11] =?UTF-8?q?fix:=20=E6=95=8F=E6=84=9F=E8=AF=8D?=
=?UTF-8?q?=E5=BA=93=E4=B8=BA=E7=A9=BA=E6=97=B6=EF=BC=8C=E4=B8=8D=E6=A3=80?=
=?UTF-8?q?=E6=B5=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
constant/sensitive.go | 8 +++++++-
service/sensitive.go | 6 ++++++
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/constant/sensitive.go b/constant/sensitive.go
index fabca67..4a9a7e6 100644
--- a/constant/sensitive.go
+++ b/constant/sensitive.go
@@ -23,7 +23,13 @@ func SensitiveWordsToString() string {
}
func SensitiveWordsFromString(s string) {
- SensitiveWords = strings.Split(s, "\n")
+ sw := strings.Split(s, "\n")
+ for _, w := range sw {
+ w = strings.TrimSpace(w)
+ if w != "" {
+ SensitiveWords = append(SensitiveWords, w)
+ }
+ }
}
func ShouldCheckPromptSensitive() bool {
diff --git a/service/sensitive.go b/service/sensitive.go
index b216376..dbb0887 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -10,6 +10,9 @@ import (
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) {
+ if len(constant.SensitiveWords) == 0 {
+ return false, nil
+ }
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := initAc()
@@ -26,6 +29,9 @@ func SensitiveWordContains(text string) (bool, []string) {
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
+ if len(constant.SensitiveWords) == 0 {
+ return false, nil, text
+ }
checkText := strings.ToLower(text)
m := initAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
From d7e25e16046dea684982907be473c09a9c9020a4 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 20 Mar 2024 23:58:42 +0800
Subject: [PATCH 11/11] fix: fix SensitiveWords load error
---
constant/sensitive.go | 1 +
1 file changed, 1 insertion(+)
diff --git a/constant/sensitive.go b/constant/sensitive.go
index 4a9a7e6..8d8e15f 100644
--- a/constant/sensitive.go
+++ b/constant/sensitive.go
@@ -23,6 +23,7 @@ func SensitiveWordsToString() string {
}
func SensitiveWordsFromString(s string) {
+ SensitiveWords = []string{}
sw := strings.Split(s, "\n")
for _, w := range sw {
w = strings.TrimSpace(w)