From 222a55387d775d5623e5d27db1f3f3d80ab276dd Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 21 Mar 2024 14:29:56 +0800 Subject: [PATCH] fix: fix SensitiveWords error --- dto/text_response.go | 11 ++++++++--- relay/channel/openai/relay-openai.go | 19 ++++++++++++------- relay/common/relay_utils.go | 4 ++-- service/sensitive.go | 2 +- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/dto/text_response.go b/dto/text_response.go index 7b94ed5..63a344d 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -1,9 +1,14 @@ package dto -type TextResponse struct { - Choices []*OpenAITextResponseChoice `json:"choices"` +type TextResponseWithError struct { + Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` + Error OpenAIError `json:"error"` +} + +type TextResponse struct { + Choices []OpenAITextResponseChoice `json:"choices"` Usage `json:"usage"` - Error *OpenAIError `json:"error,omitempty"` } type OpenAITextResponseChoice struct { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 7e36861..349d5d5 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -125,7 +125,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { - var textResponse dto.TextResponse + var textResponseWithError dto.TextResponseWithError responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil @@ -134,18 +134,23 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } - err = json.Unmarshal(responseBody, &textResponse) + err = json.Unmarshal(responseBody, &textResponseWithError) if err != nil { + log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err) return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil } - log.Printf("textResponse: %+v", textResponse) - if textResponse.Error != nil { + if textResponseWithError.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - Error: *textResponse.Error, + Error: textResponseWithError.Error, StatusCode: resp.StatusCode, }, nil, nil } + textResponse := &dto.TextResponse{ + Choices: textResponseWithError.Choices, + Usage: textResponseWithError.Usage, + } + checkSensitive := constant.ShouldCheckCompletionSensitive() sensitiveWords := make([]string, 0) triggerSensitive := false @@ -174,7 +179,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } } - if constant.StopOnSensitiveEnabled { + if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive { } else { responseBody, err = json.Marshal(textResponse) @@ -200,7 +205,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if checkSensitive && triggerSensitive { sensitiveWords = common.RemoveDuplicate(sensitiveWords) - return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ + return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ SensitiveWords: sensitiveWords, } } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 71b8ba7..726d22f 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -35,12 +35,12 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open if err != nil { return } - var textResponse dto.TextResponse + var textResponse dto.TextResponseWithError err = json.Unmarshal(responseBody, &textResponse) if err != nil { return } - OpenAIErrorWithStatusCode.Error = *textResponse.Error + OpenAIErrorWithStatusCode.Error = textResponse.Error return } diff --git a/service/sensitive.go b/service/sensitive.go index dbb0887..51621c3 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -40,7 +40,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, for _, hit := range hits { pos := hit.Pos word := string(hit.Word) - text = text[:pos] + "*###*" + text[pos+len(word):] + text = text[:pos] + "**###**" + text[pos+len(word):] words = append(words, word) } return true, words, text