feat: 初步兼容敏感词过滤

This commit is contained in:
CaIon
2024-03-20 17:07:42 +08:00
parent bec21ade9d
commit 7a663d26ec
22 changed files with 293 additions and 66 deletions

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/service"
"strings"
@@ -194,7 +195,7 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage dto.Usage
var usage *dto.Usage
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
@@ -277,13 +278,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
} else {
if usage.CompletionTokens == 0 {
usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
}
}
return nil, &usage
return nil, usage
}
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -312,7 +313,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens := service.CountTokenText(claudeResponse.Completion, model)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens

View File

@@ -51,7 +51,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -256,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -75,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/service"
@@ -153,7 +154,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += service.CountTokenText(string(choice.Message.Content), model)
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive())
completionTokens += ctkm
}
textResponse.Usage = dto.Usage{
PromptTokens: promptTokens,

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -156,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -47,7 +47,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -57,7 +57,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}

View File

@@ -48,7 +48,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -10,6 +10,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -62,8 +63,16 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
var err error
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if strings.HasPrefix(audioRequest.Model, "tts-1") {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
preConsumedTokens = promptTokens
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
@@ -161,12 +170,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
quota := 0
var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens = quota
quota = promptTokens
} else {
quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
@@ -208,6 +215,10 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
contains, words := service.SensitiveWordContains(audioResponse.Text)
if contains {
return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

View File

@@ -10,6 +10,7 @@ import (
"math"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -96,6 +97,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
@@ -172,16 +174,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model)
promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
default:
err = errors.New("unknown relay mode")
promptTokens = 0