feat: 初步兼容敏感词过滤

This commit is contained in:
CaIon
2024-03-20 17:07:42 +08:00
parent bec21ade9d
commit 7a663d26ec
22 changed files with 293 additions and 66 deletions

View File

@@ -10,6 +10,7 @@ import (
"math"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -96,6 +97,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
@@ -172,16 +174,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
default:
err = errors.New("unknown relay mode")
promptTokens = 0