mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: 初步兼容敏感词过滤
This commit is contained in:
		
							
								
								
									
										38
									
								
								common/str.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								common/str.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
func SundaySearch(text string, pattern string) bool {
 | 
			
		||||
	// 计算偏移表
 | 
			
		||||
	offset := make(map[rune]int)
 | 
			
		||||
	for i, c := range pattern {
 | 
			
		||||
		offset[c] = len(pattern) - i
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 文本串长度和模式串长度
 | 
			
		||||
	n, m := len(text), len(pattern)
 | 
			
		||||
 | 
			
		||||
	// 主循环,i表示当前对齐的文本串位置
 | 
			
		||||
	for i := 0; i <= n-m; {
 | 
			
		||||
		// 检查子串
 | 
			
		||||
		j := 0
 | 
			
		||||
		for j < m && text[i+j] == pattern[j] {
 | 
			
		||||
			j++
 | 
			
		||||
		}
 | 
			
		||||
		// 如果完全匹配,返回匹配位置
 | 
			
		||||
		if j == m {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果还有剩余字符,则检查下一位字符在偏移表中的值
 | 
			
		||||
		if i+m < n {
 | 
			
		||||
			next := rune(text[i+m])
 | 
			
		||||
			if val, ok := offset[next]; ok {
 | 
			
		||||
				i += val // 存在于偏移表中,进行跳跃
 | 
			
		||||
			} else {
 | 
			
		||||
				i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false // 如果没有找到匹配,返回-1
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								constant/sensitive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								constant/sensitive.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
			
		||||
package constant
 | 
			
		||||
 | 
			
		||||
import "strings"
 | 
			
		||||
 | 
			
		||||
var CheckSensitiveEnabled = true
 | 
			
		||||
var CheckSensitiveOnPromptEnabled = true
 | 
			
		||||
var CheckSensitiveOnCompletionEnabled = true
 | 
			
		||||
 | 
			
		||||
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
 | 
			
		||||
var StopOnSensitiveEnabled = true
 | 
			
		||||
 | 
			
		||||
// SensitiveWords 敏感词
 | 
			
		||||
// var SensitiveWords []string
 | 
			
		||||
var SensitiveWords = []string{
 | 
			
		||||
	"test",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SensitiveWordsToString() string {
 | 
			
		||||
	return strings.Join(SensitiveWords, "\n")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SensitiveWordsFromString(s string) {
 | 
			
		||||
	SensitiveWords = strings.Split(s, "\n")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ShouldCheckPromptSensitive() bool {
 | 
			
		||||
	return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ShouldCheckCompletionSensitive() bool {
 | 
			
		||||
	return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -4,6 +4,7 @@ module one-api
 | 
			
		||||
go 1.18
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
 | 
			
		||||
	github.com/gin-contrib/cors v1.4.0
 | 
			
		||||
	github.com/gin-contrib/gzip v0.0.6
 | 
			
		||||
	github.com/gin-contrib/sessions v0.0.5
 | 
			
		||||
@@ -27,6 +28,7 @@ require (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 | 
			
		||||
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										20
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,3 +1,7 @@
 | 
			
		||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
 | 
			
		||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
 | 
			
		||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
 | 
			
		||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
@@ -15,8 +19,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
 | 
			
		||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
 | 
			
		||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
 | 
			
		||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
 | 
			
		||||
@@ -37,7 +39,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
 | 
			
		||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
 | 
			
		||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
 | 
			
		||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
 | 
			
		||||
@@ -48,8 +49,6 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
 | 
			
		||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
 | 
			
		||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
 | 
			
		||||
@@ -105,8 +104,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 | 
			
		||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 | 
			
		||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
 | 
			
		||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
 | 
			
		||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
 | 
			
		||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
 | 
			
		||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
 | 
			
		||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
 | 
			
		||||
@@ -154,9 +151,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
 | 
			
		||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 | 
			
		||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 | 
			
		||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 | 
			
		||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 | 
			
		||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
 | 
			
		||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 | 
			
		||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
 | 
			
		||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
 | 
			
		||||
@@ -175,8 +171,6 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
 | 
			
		||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 | 
			
		||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
 | 
			
		||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
 | 
			
		||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
 | 
			
		||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
 | 
			
		||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
 | 
			
		||||
@@ -184,8 +178,6 @@ golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9
 | 
			
		||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
 | 
			
		||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
 | 
			
		||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
 | 
			
		||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
 | 
			
		||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
 | 
			
		||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
 | 
			
		||||
@@ -200,8 +192,6 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
 | 
			
		||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
 | 
			
		||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
 
 | 
			
		||||
@@ -90,6 +90,11 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
 | 
			
		||||
	common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
 | 
			
		||||
	common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
 | 
			
		||||
	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 | 
			
		||||
	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
 | 
			
		||||
	common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
 | 
			
		||||
	common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
 | 
			
		||||
	common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
 | 
			
		||||
 | 
			
		||||
	common.OptionMapRWMutex.Unlock()
 | 
			
		||||
	loadOptionsFromDatabase()
 | 
			
		||||
@@ -185,6 +190,14 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			common.DefaultCollapseSidebar = boolValue
 | 
			
		||||
		case "MjNotifyEnabled":
 | 
			
		||||
			constant.MjNotifyEnabled = boolValue
 | 
			
		||||
		case "CheckSensitiveEnabled":
 | 
			
		||||
			constant.CheckSensitiveEnabled = boolValue
 | 
			
		||||
		case "CheckSensitiveOnPromptEnabled":
 | 
			
		||||
			constant.CheckSensitiveOnPromptEnabled = boolValue
 | 
			
		||||
		case "CheckSensitiveOnCompletionEnabled":
 | 
			
		||||
			constant.CheckSensitiveOnCompletionEnabled = boolValue
 | 
			
		||||
		case "StopOnSensitiveEnabled":
 | 
			
		||||
			constant.StopOnSensitiveEnabled = boolValue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	switch key {
 | 
			
		||||
@@ -273,6 +286,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
 | 
			
		||||
	case "QuotaPerUnit":
 | 
			
		||||
		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
 | 
			
		||||
	case "SensitiveWords":
 | 
			
		||||
		constant.SensitiveWordsFromString(value)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -194,7 +195,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 | 
			
		||||
 | 
			
		||||
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 | 
			
		||||
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
			
		||||
	var usage dto.Usage
 | 
			
		||||
	var usage *dto.Usage
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	createdTime := common.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
@@ -277,13 +278,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if requestMode == RequestModeCompletion {
 | 
			
		||||
		usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		if usage.CompletionTokens == 0 {
 | 
			
		||||
			usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
 | 
			
		||||
			usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 | 
			
		||||
@@ -312,7 +313,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
 | 
			
		||||
	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
 | 
			
		||||
	completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	usage := dto.Usage{}
 | 
			
		||||
	if requestMode == RequestModeCompletion {
 | 
			
		||||
		usage.PromptTokens = promptTokens
 | 
			
		||||
 
 | 
			
		||||
@@ -51,7 +51,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = geminiChatStreamHandler(c, resp)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
@@ -256,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
 | 
			
		||||
	completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
 | 
			
		||||
	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
	usage := dto.Usage{
 | 
			
		||||
		PromptTokens:     promptTokens,
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
 
 | 
			
		||||
@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
@@ -153,7 +154,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 {
 | 
			
		||||
		completionTokens := 0
 | 
			
		||||
		for _, choice := range textResponse.Choices {
 | 
			
		||||
			completionTokens += service.CountTokenText(string(choice.Message.Content), model)
 | 
			
		||||
			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
			completionTokens += ctkm
 | 
			
		||||
		}
 | 
			
		||||
		textResponse.Usage = dto.Usage{
 | 
			
		||||
			PromptTokens:     promptTokens,
 | 
			
		||||
 
 | 
			
		||||
@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = palmStreamHandler(c, resp)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
@@ -156,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
	completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
 | 
			
		||||
	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
	usage := dto.Usage{
 | 
			
		||||
		PromptTokens:     promptTokens,
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
 
 | 
			
		||||
@@ -47,7 +47,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -57,7 +57,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = tencentStreamHandler(c, resp)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = tencentHandler(c, resp)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
@@ -62,8 +63,16 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
			return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	promptTokens := 0
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 | 
			
		||||
		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		preConsumedTokens = promptTokens
 | 
			
		||||
	}
 | 
			
		||||
	modelRatio := common.GetModelRatio(audioRequest.Model)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
@@ -161,12 +170,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		go func() {
 | 
			
		||||
			useTimeSeconds := time.Now().Unix() - startTime.Unix()
 | 
			
		||||
			quota := 0
 | 
			
		||||
			var promptTokens = 0
 | 
			
		||||
			if strings.HasPrefix(audioRequest.Model, "tts-1") {
 | 
			
		||||
				quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 | 
			
		||||
				promptTokens = quota
 | 
			
		||||
				quota = promptTokens
 | 
			
		||||
			} else {
 | 
			
		||||
				quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 | 
			
		||||
				quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
			}
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
			if ratio != 0 && quota <= 0 {
 | 
			
		||||
@@ -208,6 +215,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		contains, words := service.SensitiveWordContains(audioResponse.Text)
 | 
			
		||||
		if contains {
 | 
			
		||||
			return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
@@ -96,6 +97,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	var preConsumedQuota int
 | 
			
		||||
	var ratio float64
 | 
			
		||||
	var modelRatio float64
 | 
			
		||||
	//err := service.SensitiveWordsCheck(textRequest)
 | 
			
		||||
	promptTokens, err := getPromptTokens(textRequest, relayInfo)
 | 
			
		||||
 | 
			
		||||
	// count messages token error 计算promptTokens错误
 | 
			
		||||
@@ -172,16 +174,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	checkSensitive := constant.ShouldCheckPromptSensitive()
 | 
			
		||||
	switch info.RelayMode {
 | 
			
		||||
	case relayconstant.RelayModeChatCompletions:
 | 
			
		||||
		promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
		promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
 | 
			
		||||
	case relayconstant.RelayModeCompletions:
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 | 
			
		||||
	case relayconstant.RelayModeModerations:
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 | 
			
		||||
	case relayconstant.RelayModeEmbeddings:
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
 | 
			
		||||
		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 | 
			
		||||
	default:
 | 
			
		||||
		err = errors.New("unknown relay mode")
 | 
			
		||||
		promptTokens = 0
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										60
									
								
								service/sensitive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								service/sensitive.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,60 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/anknown/ahocorasick"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
 | 
			
		||||
func SensitiveWordContains(text string) (bool, []string) {
 | 
			
		||||
	// 构建一个AC自动机
 | 
			
		||||
	m := initAc()
 | 
			
		||||
	hits := m.MultiPatternSearch([]rune(text), false)
 | 
			
		||||
	if len(hits) > 0 {
 | 
			
		||||
		words := make([]string, 0)
 | 
			
		||||
		for _, hit := range hits {
 | 
			
		||||
			words = append(words, string(hit.Word))
 | 
			
		||||
		}
 | 
			
		||||
		return true, words
 | 
			
		||||
	}
 | 
			
		||||
	return false, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
 | 
			
		||||
func SensitiveWordReplace(text string) (bool, string) {
 | 
			
		||||
	m := initAc()
 | 
			
		||||
	hits := m.MultiPatternSearch([]rune(text), false)
 | 
			
		||||
	if len(hits) > 0 {
 | 
			
		||||
		for _, hit := range hits {
 | 
			
		||||
			pos := hit.Pos
 | 
			
		||||
			word := string(hit.Word)
 | 
			
		||||
			text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):]
 | 
			
		||||
		}
 | 
			
		||||
		return true, text
 | 
			
		||||
	}
 | 
			
		||||
	return false, text
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initAc() *goahocorasick.Machine {
 | 
			
		||||
	m := new(goahocorasick.Machine)
 | 
			
		||||
	dict := readRunes()
 | 
			
		||||
	if err := m.Build(dict); err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readRunes() [][]rune {
 | 
			
		||||
	var dict [][]rune
 | 
			
		||||
 | 
			
		||||
	for _, word := range constant.SensitiveWords {
 | 
			
		||||
		l := bytes.TrimSpace([]byte(word))
 | 
			
		||||
		dict = append(dict, bytes.Runes(l))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dict
 | 
			
		||||
}
 | 
			
		||||
@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 | 
			
		||||
	return tiles*170 + 85, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenMessages(messages []dto.Message, model string) (int, error) {
 | 
			
		||||
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
 | 
			
		||||
	//recover when panic
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	// Reference:
 | 
			
		||||
@@ -144,6 +144,13 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
 | 
			
		||||
				if err := json.Unmarshal(message.Content, &stringContent); err != nil {
 | 
			
		||||
					return 0, err
 | 
			
		||||
				} else {
 | 
			
		||||
					if checkSensitive {
 | 
			
		||||
						contains, words := SensitiveWordContains(stringContent)
 | 
			
		||||
						if contains {
 | 
			
		||||
							err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
 | 
			
		||||
							return 0, err
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					tokenNum += getTokenNum(tokenEncoder, stringContent)
 | 
			
		||||
					if message.Name != nil {
 | 
			
		||||
						tokenNum += tokensPerName
 | 
			
		||||
@@ -190,29 +197,37 @@ func CountTokenMessages(messages []dto.Message, model string) (int, error) {
 | 
			
		||||
	return tokenNum, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenInput(input any, model string) int {
 | 
			
		||||
func CountTokenInput(input any, model string, check bool) (int, error) {
 | 
			
		||||
	switch v := input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return CountTokenText(v, model)
 | 
			
		||||
		return CountTokenText(v, model, check)
 | 
			
		||||
	case []string:
 | 
			
		||||
		text := ""
 | 
			
		||||
		for _, s := range v {
 | 
			
		||||
			text += s
 | 
			
		||||
		}
 | 
			
		||||
		return CountTokenText(text, model)
 | 
			
		||||
		return CountTokenText(text, model, check)
 | 
			
		||||
	}
 | 
			
		||||
	return 0
 | 
			
		||||
	return 0, errors.New("unsupported input type")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountAudioToken(text string, model string) int {
 | 
			
		||||
func CountAudioToken(text string, model string, check bool) (int, error) {
 | 
			
		||||
	if strings.HasPrefix(model, "tts") {
 | 
			
		||||
		return utf8.RuneCountInString(text)
 | 
			
		||||
		return utf8.RuneCountInString(text), nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return CountTokenText(text, model)
 | 
			
		||||
		return CountTokenText(text, model, check)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenText(text string, model string) int {
 | 
			
		||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
 | 
			
		||||
func CountTokenText(text string, model string, check bool) (int, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	if check {
 | 
			
		||||
		contains, words := SensitiveWordContains(text)
 | 
			
		||||
		if contains {
 | 
			
		||||
			err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	return getTokenNum(tokenEncoder, text)
 | 
			
		||||
	return getTokenNum(tokenEncoder, text), err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,27 +1,26 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case constant.RelayModeChatCompletions:
 | 
			
		||||
		return CountTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
	case constant.RelayModeCompletions:
 | 
			
		||||
		return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
 | 
			
		||||
	case constant.RelayModeModerations:
 | 
			
		||||
		return CountTokenInput(textRequest.Input, textRequest.Model), nil
 | 
			
		||||
	}
 | 
			
		||||
	return 0, errors.New("unknown relay mode")
 | 
			
		||||
}
 | 
			
		||||
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
 | 
			
		||||
//	switch relayMode {
 | 
			
		||||
//	case constant.RelayModeChatCompletions:
 | 
			
		||||
//		return CountTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
//	case constant.RelayModeCompletions:
 | 
			
		||||
//		return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
 | 
			
		||||
//	case constant.RelayModeModerations:
 | 
			
		||||
//		return CountTokenInput(textRequest.Input, textRequest.Model), nil
 | 
			
		||||
//	}
 | 
			
		||||
//	return 0, errors.New("unknown relay mode")
 | 
			
		||||
//}
 | 
			
		||||
 | 
			
		||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
 | 
			
		||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 | 
			
		||||
	usage := &dto.Usage{}
 | 
			
		||||
	usage.PromptTokens = promptTokens
 | 
			
		||||
	usage.CompletionTokens = CountTokenText(responseText, modeName)
 | 
			
		||||
	ctkm, err := CountTokenText(responseText, modeName, false)
 | 
			
		||||
	usage.CompletionTokens = ctkm
 | 
			
		||||
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 | 
			
		||||
	return usage
 | 
			
		||||
	return usage, err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,6 +23,11 @@ const OperationSetting = () => {
 | 
			
		||||
    LogConsumeEnabled: '',
 | 
			
		||||
    DisplayInCurrencyEnabled: '',
 | 
			
		||||
    DisplayTokenStatEnabled: '',
 | 
			
		||||
    CheckSensitiveEnabled: '',
 | 
			
		||||
    CheckSensitiveOnPromptEnabled: '',
 | 
			
		||||
    CheckSensitiveOnCompletionEnabled: '',
 | 
			
		||||
    StopOnSensitiveEnabled: '',
 | 
			
		||||
    SensitiveWords: '',
 | 
			
		||||
    MjNotifyEnabled: '',
 | 
			
		||||
    DrawingEnabled: '',
 | 
			
		||||
    DataExportEnabled: '',
 | 
			
		||||
@@ -130,6 +135,11 @@ const OperationSetting = () => {
 | 
			
		||||
          await updateOption('ModelPrice', inputs.ModelPrice);
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
      case 'words':
 | 
			
		||||
        if (originInputs['SensitiveWords'] !== inputs.SensitiveWords) {
 | 
			
		||||
          await updateOption('SensitiveWords', inputs.SensitiveWords);
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
      case 'quota':
 | 
			
		||||
        if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
 | 
			
		||||
          await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
 | 
			
		||||
@@ -273,6 +283,51 @@ const OperationSetting = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as="h3">
 | 
			
		||||
            屏蔽词过滤设置
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.CheckSensitiveEnabled === 'true'}
 | 
			
		||||
              label="启用屏蔽词过滤功能"
 | 
			
		||||
              name="CheckSensitiveEnabled"
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.CheckSensitiveOnPromptEnabled === 'true'}
 | 
			
		||||
              label="启用prompt检查"
 | 
			
		||||
              name="CheckSensitiveOnPromptEnabled"
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}
 | 
			
		||||
              label="启用生成内容检查"
 | 
			
		||||
              name="CheckSensitiveOnCompletionEnabled"
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.StopOnSensitiveEnabled === 'true'}
 | 
			
		||||
              label="在检测到屏蔽词时,立刻停止生成,否则替换屏蔽词"
 | 
			
		||||
              name="StopOnSensitiveEnabled"
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths="equal">
 | 
			
		||||
            <Form.TextArea
 | 
			
		||||
              label="屏蔽词列表,一行一个屏蔽词,不需要符号分割"
 | 
			
		||||
              name="SensitiveWords"
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
 | 
			
		||||
              value={inputs.SensitiveWords}
 | 
			
		||||
              placeholder="一行一个屏蔽词"
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={() => {
 | 
			
		||||
            submitConfig('words').then();
 | 
			
		||||
          }}>保存屏蔽词设置</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as="h3">
 | 
			
		||||
            日志设置
 | 
			
		||||
          </Header>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user