diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 349d5d5..3698a70 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -157,7 +157,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if textResponse.Usage.TotalTokens == 0 || checkSensitive { completionTokens := 0 - for _, choice := range textResponse.Choices { + for i, choice := range textResponse.Choices { stringContent := string(choice.Message.Content) ctkm, _, _ := service.CountTokenText(stringContent, model, false) completionTokens += ctkm @@ -167,7 +167,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model triggerSensitive = true msg := choice.Message msg.Content = common.StringToByteSlice(stringContent) - choice.Message = msg + textResponse.Choices[i].Message = msg sensitiveWords = append(sensitiveWords, words...) } } @@ -179,8 +179,13 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } } - if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive { - + if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled { + sensitiveWords = common.RemoveDuplicate(sensitiveWords) + return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", + strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), + &textResponse.Usage, &dto.SensitiveResponse{ + SensitiveWords: sensitiveWords, + } } else { responseBody, err = json.Marshal(textResponse) // Reset response body @@ -202,12 +207,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } } - - if checkSensitive && triggerSensitive { - sensitiveWords = common.RemoveDuplicate(sensitiveWords) - return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ - SensitiveWords: sensitiveWords, - } - } return nil, &textResponse.Usage, nil }