From 64b9d3b58c03a35d2b616c8a7e207f1b9fe807a1 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 20 Mar 2024 19:00:51 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E7=94=9F=E6=88=90=E5=86=85=E5=AE=B9=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 12 ++++ controller/channel-test.go | 2 +- dto/sensitive.go | 6 ++ dto/text_response.go | 4 +- relay/channel/adapter.go | 2 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/baidu/adaptor.go | 2 +- relay/channel/claude/adaptor.go | 2 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 4 +- relay/channel/openai/adaptor.go | 4 +- relay/channel/openai/relay-openai.go | 101 +++++++++++++++++++-------- relay/channel/palm/adaptor.go | 2 +- relay/channel/perplexity/adaptor.go | 4 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/xunfei/adaptor.go | 6 +- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 4 +- relay/common/relay_utils.go | 2 +- relay/relay-text.go | 25 +++++-- service/sensitive.go | 14 ++-- 21 files changed, 141 insertions(+), 63 deletions(-) create mode 100644 dto/sensitive.go diff --git a/common/str.go b/common/str.go index d16f7a0..ddf8375 100644 --- a/common/str.go +++ b/common/str.go @@ -36,3 +36,15 @@ func SundaySearch(text string, pattern string) bool { } return false // 如果没有找到匹配,返回-1 } + +func RemoveDuplicate(s []string) []string { + result := make([]string, 0, len(s)) + temp := map[string]struct{}{} + for _, item := range s { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + result = append(result, item) + } + } + return result +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 4a2906b..2eee554 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr err := relaycommon.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } - usage, respErr := adaptor.DoResponse(c, resp, meta) + usage, respErr, _ := adaptor.DoResponse(c, resp, meta) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error } diff --git a/dto/sensitive.go b/dto/sensitive.go new file mode 100644 index 0000000..0bfbc6f --- /dev/null +++ b/dto/sensitive.go @@ -0,0 +1,6 @@ +package dto + +type SensitiveResponse struct { + SensitiveWords []string `json:"sensitive_words"` + Content string `json:"content"` +} diff --git a/dto/text_response.go b/dto/text_response.go index 81d0748..7b94ed5 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -1,9 +1,9 @@ package dto type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` + Choices []*OpenAITextResponseChoice `json:"choices"` Usage `json:"usage"` - Error OpenAIError `json:"error"` + Error *OpenAIError `json:"error,omitempty"` } type OpenAITextResponseChoice struct { diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d3886d5..935a0f0 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,7 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfe83db..155d5a0 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = aliStreamHandler(c, resp) } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2571dc..d8c96aa 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = baiduStreamHandler(c, resp) } else { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 45efd01..ee5a4c0 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) } else { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a275175..5f78829 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = geminiChatStreamHandler(c, resp) diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 55edf7a..6ef2f30 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -39,13 +39,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 417dbce..d9c52f8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -71,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index a3a2634..4409555 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,8 +4,11 @@ import ( "bufio" "bytes" "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "io" + "log" "net/http" "one-api/common" "one-api/constant" @@ -18,6 +21,7 @@ import ( ) func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { + checkSensitive := constant.ShouldCheckCompletionSensitive() var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -37,11 +41,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d defer close(stopChan) defer close(dataChan) var wg sync.WaitGroup - go func() { wg.Add(1) defer wg.Done() - var streamItems []string + var streamItems []string // store stream items for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format @@ -50,11 +53,20 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d if data[:6] != "data: " && data[:6] != "[DONE]" { continue } + sensitive := false + if checkSensitive { + // check sensitive + sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled) + } dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { streamItems = append(streamItems, data) } + if sensitive && constant.StopOnSensitiveEnabled { + dataChan <- "data: [DONE]" + break + } } streamResp := "[" + strings.Join(streamItems, ",") + "]" switch relayMode { @@ -112,50 +124,48 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d return nil, responseTextBuilder.String() } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { var textResponse dto.TextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil } - if textResponse.Error.Type != "" { + log.Printf("textResponse: %+v", textResponse) + if textResponse.Error != nil { return &dto.OpenAIErrorWithStatusCode{ - Error: textResponse.Error, + Error: *textResponse.Error, StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + }, nil, nil } - if textResponse.Usage.TotalTokens == 0 { + checkSensitive := constant.ShouldCheckCompletionSensitive() + sensitiveWords := make([]string, 0) + triggerSensitive := false + + if textResponse.Usage.TotalTokens == 0 || checkSensitive { completionTokens := 0 for _, choice := range textResponse.Choices { - ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive()) + stringContent := string(choice.Message.Content) + ctkm, _ := service.CountTokenText(stringContent, model, false) completionTokens += ctkm + if checkSensitive { + sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) + if sensitive { + triggerSensitive = true + msg := choice.Message + msg.Content = common.StringToByteSlice(stringContent) + choice.Message = msg + sensitiveWords = append(sensitiveWords, words...) + } + } } textResponse.Usage = dto.Usage{ PromptTokens: promptTokens, @@ -163,5 +173,36 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model TotalTokens: promptTokens + completionTokens, } } - return nil, &textResponse.Usage + + if constant.StopOnSensitiveEnabled { + + } else { + responseBody, err = json.Marshal(textResponse) + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + } + } + + if checkSensitive && triggerSensitive { + sensitiveWords = common.RemoveDuplicate(sensitiveWords) + return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ + SensitiveWords: sensitiveWords, + } + } + return nil, &textResponse.Usage, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4f59a44..2129ee3 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 24765ff..8d056ec 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 470ec14..fa3da0c 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 79a4b12..9ab2818 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { splits := strings.Split(info.ApiKey, "|") if len(splits) != 3 { - return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil } if a.request == nil { - return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) + return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil } if info.IsStream { err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 6f2d186..69c45a8 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { err, usage = zhipuStreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 1b8866b..dded3c5 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index f89b8b9..71b8ba7 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open if err != nil { return } - OpenAIErrorWithStatusCode.Error = textResponse.Error + OpenAIErrorWithStatusCode.Error = *textResponse.Error return } diff --git a/relay/relay-text.go b/relay/relay-text.go index 8a38a81..b9c1e7a 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.RelayErrorHandler(resp) } - usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return openaiErr + if sensitiveResp == nil { // 如果没有敏感词检查结果 + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + return openaiErr + } else { + // 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额 + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp) + if constant.StopOnSensitiveEnabled { // 是否直接返回错误 + return openaiErr + } + return nil + } } - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil) return nil } @@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu } } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, + usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, + modelPrice float64, sensitiveResp *dto.SensitiveResponse) { + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe logContent += fmt.Sprintf("(可能是上游超时)") common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) } else { + if sensitiveResp != nil { + logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { diff --git a/service/sensitive.go b/service/sensitive.go index 6b77849..57c667f 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -24,18 +24,21 @@ func SensitiveWordContains(text string) (bool, []string) { } // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 -func SensitiveWordReplace(text string) (bool, string) { +func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) { + text = strings.ToLower(text) m := initAc() - hits := m.MultiPatternSearch([]rune(text), false) + hits := m.MultiPatternSearch([]rune(text), returnImmediately) if len(hits) > 0 { + words := make([]string, 0) for _, hit := range hits { pos := hit.Pos word := string(hit.Word) - text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):] + text = text[:pos] + " *###* " + text[pos+len(word):] + words = append(words, word) } - return true, text + return true, words, text } - return false, text + return false, nil, text } func initAc() *goahocorasick.Machine { @@ -52,6 +55,7 @@ func readRunes() [][]rune { var dict [][]rune for _, word := range constant.SensitiveWords { + word = strings.ToLower(word) l := bytes.TrimSpace([]byte(word)) dict = append(dict, bytes.Runes(l)) }