From 44a8ade4bac9a35af0bd683ec9f0e594c6bca392 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Fri, 29 Mar 2024 22:20:14 +0800
Subject: [PATCH] fix: remove sensitive check on completion (close #157)

---
 constant/sensitive.go                  |   9 +-
 controller/channel-test.go             |   2 +-
 dto/text_response.go                   |   6 +
 model/option.go                        |   6 +-
 relay/channel/adapter.go               |   2 +-
 relay/channel/ali/adaptor.go           |   2 +-
 relay/channel/baidu/adaptor.go         |   2 +-
 relay/channel/claude/adaptor.go        |   2 +-
 relay/channel/claude/relay-claude.go   |   3 +-
 relay/channel/gemini/adaptor.go        |   2 +-
 relay/channel/gemini/relay-gemini.go   |   3 +-
 relay/channel/ollama/adaptor.go        |   6 +-
 relay/channel/ollama/relay-ollama.go   |  16 +--
 relay/channel/openai/adaptor.go        |   4 +-
 relay/channel/openai/relay-openai.go   | 151 +++++++------------------
 relay/channel/palm/adaptor.go          |   2 +-
 relay/channel/palm/relay-palm.go       |   3 +-
 relay/channel/perplexity/adaptor.go    |   4 +-
 relay/channel/tencent/adaptor.go       |   2 +-
 relay/channel/xunfei/adaptor.go        |   6 +-
 relay/channel/zhipu/adaptor.go         |   2 +-
 relay/channel/zhipu_4v/adaptor.go      |   4 +-
 relay/relay-audio.go                   |   2 +-
 relay/relay-text.go                    |  25 ++--
 web/src/components/OperationSetting.js |  28 ++---
 25 files changed, 107 insertions(+), 187 deletions(-)
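
Note (editor's sketch, not part of the applied diff): after this change,
Adaptor.DoResponse returns only (usage, error); the *dto.SensitiveResponse
result is dropped from the interface and every channel adaptor, and the
completion-side sensitive checks are commented out rather than deleted. A
minimal sketch of the narrowed call shape, using simplified stand-in types
instead of one-api's real dto package:

    package main

    import "fmt"

    // Stand-ins for dto.Usage and dto.OpenAIErrorWithStatusCode.
    type Usage struct{ PromptTokens, CompletionTokens, TotalTokens int }
    type OpenAIError struct {
        StatusCode int
        Message    string
    }

    // doResponse models the post-patch contract: two results, no sensitive response.
    func doResponse() (*Usage, *OpenAIError) {
        return &Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, nil
    }

    func main() {
        usage, openaiErr := doResponse()
        if openaiErr != nil {
            // Callers now always refund the pre-consumed quota on error;
            // the sensitive-response billing branch is gone (see relay/relay-text.go).
            fmt.Println("refund pre-consumed quota:", openaiErr.Message)
            return
        }
        fmt.Printf("bill %d total tokens\n", usage.TotalTokens)
    }

The new dto.SimpleResponse in relay/channel/openai/relay-openai.go re-serves
the upstream body verbatim and recomputes usage from the choices only when the
upstream reports zero total tokens.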
common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) + //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength) @@ -196,8 +196,8 @@ func updateOptionMap(key string, value string) (err error) { constant.CheckSensitiveEnabled = boolValue case "CheckSensitiveOnPromptEnabled": constant.CheckSensitiveOnPromptEnabled = boolValue - case "CheckSensitiveOnCompletionEnabled": - constant.CheckSensitiveOnCompletionEnabled = boolValue + //case "CheckSensitiveOnCompletionEnabled": + // constant.CheckSensitiveOnCompletionEnabled = boolValue case "StopOnSensitiveEnabled": constant.StopOnSensitiveEnabled = boolValue case "SMTPSSLEnabled": diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 935a0f0..d3886d5 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,7 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 155d5a0..bfe83db 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = aliStreamHandler(c, resp) } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d8c96aa..d2571dc 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = baiduStreamHandler(c, resp) } else { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index ee5a4c0..45efd01 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -63,7 +63,7 @@ func 
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
 	} else {
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index f1a3c39..4de8dc0 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -8,7 +8,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/service"
 	"strings"
@@ -317,7 +316,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 5f78829..a275175 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = geminiChatStreamHandler(c, resp)
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 31badd8..4a10a73 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
@@ -257,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index f66d9a9..4e1fd33 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -49,16 +49,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
-			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 		} else {
-			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		}
 	}
 	return
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index fa5f818..828ddea 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -45,19 +45,19 @@ func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbedding
 	}
 }
 
-func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var ollamaEmbeddingResponse OllamaEmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
@@ -77,7 +77,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	}
 	doResponseBody, err := json.Marshal(embeddingResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -98,11 +98,11 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	return nil, usage, nil
+	return nil, usage
 }
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 7dc591e..cab6a64 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -69,13 +69,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 39127de..fe5cd48 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -4,14 +4,10 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
-	"errors"
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
-	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
@@ -21,7 +17,7 @@ import (
 )
 
 func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
-	checkSensitive := constant.ShouldCheckCompletionSensitive()
+	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 		if data[:6] != "data: " && data[:6] != "[DONE]" {
 			continue
 		}
-		sensitive := false
-		if checkSensitive {
-			// check sensitive
-			sensitive, _, data = service.SensitiveWordReplace(data, false)
-		}
 		dataChan <- data
 		data = data[6:]
 		if !strings.HasPrefix(data, "[DONE]") {
 			streamItems = append(streamItems, data)
 		}
-		if sensitive && constant.StopOnSensitiveEnabled {
-			dataChan <- "data: [DONE]"
-			break
-		}
 	}
 	streamResp := "[" + strings.Join(streamItems, ",") + "]"
 	switch relayMode {
@@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	return nil, responseTextBuilder.String()
 }
 
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
-	var responseWithError dto.TextResponseWithError
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var simpleResponse dto.SimpleResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &responseWithError)
+	err = json.Unmarshal(responseBody, &simpleResponse)
 	if err != nil {
-		log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	if responseWithError.Error.Type != "" {
+	if simpleResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			Error:      responseWithError.Error,
+			Error:      simpleResponse.Error,
 			StatusCode: resp.StatusCode,
-		}, nil, nil
+		}, nil
+	}
+	// Reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
-	checkSensitive := constant.ShouldCheckCompletionSensitive()
-	sensitiveWords := make([]string, 0)
-	triggerSensitive := false
-
-	usage := &responseWithError.Usage
-
-	//textResponse := &dto.TextResponse{
-	//	Choices: responseWithError.Choices,
-	//	Usage:   responseWithError.Usage,
-	//}
-	var doResponseBody []byte
-
-	switch relayMode {
-	case relayconstant.RelayModeEmbeddings:
-		embeddingResponse := &dto.OpenAIEmbeddingResponse{
-			Object: responseWithError.Object,
-			Data:   responseWithError.Data,
-			Model:  responseWithError.Model,
-			Usage:  *usage,
+	if simpleResponse.Usage.TotalTokens == 0 {
+		completionTokens := 0
+		for _, choice := range simpleResponse.Choices {
+			ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
+			completionTokens += ctkm
 		}
-		doResponseBody, err = json.Marshal(embeddingResponse)
-	default:
-		if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
-			completionTokens := 0
-			for i, choice := range responseWithError.Choices {
-				stringContent := string(choice.Message.Content)
-				ctkm, _, _ := service.CountTokenText(stringContent, model, false)
-				completionTokens += ctkm
-				if checkSensitive {
-					sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
-					if sensitive {
-						triggerSensitive = true
-						msg := choice.Message
-						msg.Content = common.StringToByteSlice(stringContent)
-						responseWithError.Choices[i].Message = msg
-						sensitiveWords = append(sensitiveWords, words...)
-					}
-				}
-			}
-			responseWithError.Usage = dto.Usage{
-				PromptTokens:     promptTokens,
-				CompletionTokens: completionTokens,
-				TotalTokens:      promptTokens + completionTokens,
-			}
-		}
-		textResponse := &dto.TextResponse{
-			Id:      responseWithError.Id,
-			Created: responseWithError.Created,
-			Object:  responseWithError.Object,
-			Choices: responseWithError.Choices,
-			Model:   responseWithError.Model,
-			Usage:   *usage,
-		}
-		doResponseBody, err = json.Marshal(textResponse)
-	}
-
-	if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
-		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
-		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
-			strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
-			usage, &dto.SensitiveResponse{
-				SensitiveWords: sensitiveWords,
-			}
-	} else {
-		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
-		// We shouldn't set the header before we parse the response body, because the parse part may fail.
-		// Copy headers
-		for k, v := range resp.Header {
-			// Delete any existing header with the same name first, to avoid appending duplicates
-			c.Writer.Header().Del(k)
-			for _, vv := range v {
-				c.Writer.Header().Add(k, vv)
-			}
-		}
-		// reset content length
-		c.Writer.Header().Del("Content-Length")
-		c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
-		c.Writer.WriteHeader(resp.StatusCode)
-		_, err = io.Copy(c.Writer, resp.Body)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		simpleResponse.Usage = dto.Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
 		}
 	}
-	return nil, usage, nil
+	return nil, &simpleResponse.Usage
 }
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 2129ee3..4f59a44 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 4028269..3a7d4fa 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
@@ -157,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index d04af1e..24765ff 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index fa3da0c..470ec14 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 9ab2818..79a4b12 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return dummyResp, nil
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	splits := strings.Split(info.ApiKey, "|")
 	if len(splits) != 3 {
-		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
+		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 	}
 	if a.request == nil {
-		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
+		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
 	}
 	if info.IsStream {
 		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 69c45a8..6f2d186 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = zhipuStreamHandler(c, resp)
 	} else {
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index c7ea903..1b8866b 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index 1c0f868..d4458ce 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
 		quota = promptTokens
 	} else {
-		quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+		quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
 	}
 	quota = int(float64(quota) * ratio)
 	if ratio != 0 && quota <= 0 {
diff --git a/relay/relay-text.go b/relay/relay-text.go
index ec64e04..ff653ff 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -165,21 +165,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
 	}
 
-	usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 	if openaiErr != nil {
-		if sensitiveResp == nil { // no sensitive-word check result
-			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
-			return openaiErr
-		} else {
-			// a sensitive-word result exists: do not refund the pre-consumed quota, keep consuming quota
-			postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
-			if constant.StopOnSensitiveEnabled { // whether to return the error directly
-				return openaiErr
-			}
-			return nil
-		}
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
+	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
 	return nil
 }
@@ -258,7 +249,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
+	modelPrice float64) {
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
@@ -293,9 +284,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
 	} else {
-		if sensitiveResp != nil {
-			logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		}
+		//if sensitiveResp != nil {
+		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+		//}
 	quotaDelta := quota - preConsumedQuota
 	err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
 	if err != nil {
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 3019906..f42fe57 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -330,21 +330,21 @@ const OperationSetting = () => {
             name='CheckSensitiveOnPromptEnabled'
             onChange={handleInputChange}
           />
-          <Form.Checkbox
-            checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}
-            name='CheckSensitiveOnCompletionEnabled'
-            onChange={handleInputChange} />
+          {/*<Form.Checkbox*/}
+          {/*  checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}*/}
+          {/*  name='CheckSensitiveOnCompletionEnabled'*/}
+          {/*  onChange={handleInputChange} />*/}
           {/**/}
           {/*