Compare commits

...

3 Commits

Author SHA1 Message Date
CaIon
222a55387d fix: fix SensitiveWords error 2024-03-21 14:29:56 +08:00
CaIon
d7e25e1604 fix: fix SensitiveWords load error 2024-03-20 23:58:42 +08:00
CaIon
c7658b70d1 fix: 敏感词库为空时,不检测 2024-03-20 22:33:22 +08:00
5 changed files with 37 additions and 14 deletions

View File

@@ -23,7 +23,14 @@ func SensitiveWordsToString() string {
}
func SensitiveWordsFromString(s string) {
SensitiveWords = strings.Split(s, "\n")
SensitiveWords = []string{}
sw := strings.Split(s, "\n")
for _, w := range sw {
w = strings.TrimSpace(w)
if w != "" {
SensitiveWords = append(SensitiveWords, w)
}
}
}
func ShouldCheckPromptSensitive() bool {

View File

@@ -1,9 +1,14 @@
package dto
type TextResponse struct {
Choices []*OpenAITextResponseChoice `json:"choices"`
type TextResponseWithError struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error *OpenAIError `json:"error,omitempty"`
}
type OpenAITextResponseChoice struct {

View File

@@ -125,7 +125,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponse dto.TextResponse
var textResponseWithError dto.TextResponseWithError
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@@ -134,18 +134,23 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = json.Unmarshal(responseBody, &textResponse)
err = json.Unmarshal(responseBody, &textResponseWithError)
if err != nil {
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
}
log.Printf("textResponse: %+v", textResponse)
if textResponse.Error != nil {
if textResponseWithError.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: *textResponse.Error,
Error: textResponseWithError.Error,
StatusCode: resp.StatusCode,
}, nil, nil
}
textResponse := &dto.TextResponse{
Choices: textResponseWithError.Choices,
Usage: textResponseWithError.Usage,
}
checkSensitive := constant.ShouldCheckCompletionSensitive()
sensitiveWords := make([]string, 0)
triggerSensitive := false
@@ -174,7 +179,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
}
if constant.StopOnSensitiveEnabled {
if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive {
} else {
responseBody, err = json.Marshal(textResponse)
@@ -200,7 +205,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if checkSensitive && triggerSensitive {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
}

View File

@@ -35,12 +35,12 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
if err != nil {
return
}
var textResponse dto.TextResponse
var textResponse dto.TextResponseWithError
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
OpenAIErrorWithStatusCode.Error = *textResponse.Error
OpenAIErrorWithStatusCode.Error = textResponse.Error
return
}

View File

@@ -10,6 +10,9 @@ import (
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) {
if len(constant.SensitiveWords) == 0 {
return false, nil
}
checkText := strings.ToLower(text)
// 构建一个AC自动机
m := initAc()
@@ -26,6 +29,9 @@ func SensitiveWordContains(text string) (bool, []string) {
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
if len(constant.SensitiveWords) == 0 {
return false, nil, text
}
checkText := strings.ToLower(text)
m := initAc()
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
@@ -34,7 +40,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
text = text[:pos] + "*###*" + text[pos+len(word):]
text = text[:pos] + "**###**" + text[pos+len(word):]
words = append(words, word)
}
return true, words, text