diff --git a/dto/text_response.go b/dto/text_response.go
index 63a344d..4ef06dd 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -1,12 +1,16 @@
 package dto
 
 type TextResponseWithError struct {
-	Choices []OpenAITextResponseChoice `json:"choices"`
+	Choices []OpenAITextResponseChoice    `json:"choices"`
+	Object  string                        `json:"object"`
+	Data    []OpenAIEmbeddingResponseItem `json:"data"`
+	Model   string                        `json:"model"`
 	Usage   `json:"usage"`
 	Error   OpenAIError `json:"error"`
 }
 
 type TextResponse struct {
+	Model   string                     `json:"model"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
 	Usage   `json:"usage"`
 }
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 6ef2f30..69a97e3 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index d9c52f8..9e2845d 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -77,7 +77,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 3698a70..bbb17ef 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	return nil, responseTextBuilder.String()
 }
 
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
-	var textResponseWithError dto.TextResponseWithError
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var responseWithError dto.TextResponseWithError
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@@ -134,62 +134,81 @@
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	err = json.Unmarshal(responseBody, &textResponseWithError)
+	err = json.Unmarshal(responseBody, &responseWithError)
 	if err != nil {
 		log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	if textResponseWithError.Error.Type != "" {
+	if responseWithError.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			Error:      textResponseWithError.Error,
+			Error:      responseWithError.Error,
 			StatusCode: resp.StatusCode,
 		}, nil, nil
 	}
 
-	textResponse := &dto.TextResponse{
-		Choices: textResponseWithError.Choices,
-		Usage:   textResponseWithError.Usage,
-	}
-
 	checkSensitive := constant.ShouldCheckCompletionSensitive()
 	sensitiveWords := make([]string, 0)
 	triggerSensitive := false
 
-	if textResponse.Usage.TotalTokens == 0 || checkSensitive {
-		completionTokens := 0
-		for i, choice := range textResponse.Choices {
-			stringContent := string(choice.Message.Content)
-			ctkm, _, _ := service.CountTokenText(stringContent, model, false)
-			completionTokens += ctkm
-			if checkSensitive {
-				sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
-				if sensitive {
-					triggerSensitive = true
-					msg := choice.Message
-					msg.Content = common.StringToByteSlice(stringContent)
-					textResponse.Choices[i].Message = msg
-					sensitiveWords = append(sensitiveWords, words...)
+	usage := &responseWithError.Usage
+
+	//textResponse := &dto.TextResponse{
+	//	Choices: responseWithError.Choices,
+	//	Usage:   responseWithError.Usage,
+	//}
+	var doResponseBody []byte
+
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		embeddingResponse := &dto.OpenAIEmbeddingResponse{
+			Object: responseWithError.Object,
+			Data:   responseWithError.Data,
+			Model:  responseWithError.Model,
+			Usage:  *usage,
+		}
+		doResponseBody, err = json.Marshal(embeddingResponse)
+	default:
+		if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
+			completionTokens := 0
+			for i, choice := range responseWithError.Choices {
+				stringContent := string(choice.Message.Content)
+				ctkm, _, _ := service.CountTokenText(stringContent, model, false)
+				completionTokens += ctkm
+				if checkSensitive {
+					sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
+					if sensitive {
+						triggerSensitive = true
+						msg := choice.Message
+						msg.Content = common.StringToByteSlice(stringContent)
+						responseWithError.Choices[i].Message = msg
+						sensitiveWords = append(sensitiveWords, words...)
+					}
 				}
 			}
+			responseWithError.Usage = dto.Usage{
+				PromptTokens:     promptTokens,
+				CompletionTokens: completionTokens,
+				TotalTokens:      promptTokens + completionTokens,
+			}
 		}
-		textResponse.Usage = dto.Usage{
-			PromptTokens:     promptTokens,
-			CompletionTokens: completionTokens,
-			TotalTokens:      promptTokens + completionTokens,
+		textResponse := &dto.TextResponse{
+			Choices: responseWithError.Choices,
+			Model:   responseWithError.Model,
+			Usage:   *usage,
 		}
+		doResponseBody, err = json.Marshal(textResponse)
 	}
 
 	if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
 		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
 		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
-			&textResponse.Usage, &dto.SensitiveResponse{
+			usage, &dto.SensitiveResponse{
 				SensitiveWords: sensitiveWords,
 			}
 	} else {
-		responseBody, err = json.Marshal(textResponse)
 		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+		resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
 		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 		// And then we will have to send an error response, but in this case, the header has already been set.
 		// So the httpClient will be confused by the response.
@@ -207,5 +226,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 		}
 	}
-	return nil, &textResponse.Usage, nil
+	return nil, usage, nil
 }
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index 8d056ec..d04af1e 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index dded3c5..c7ea903 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }
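
Note for reviewers: the sketch below is a minimal, self-contained illustration of the relayMode dispatch this patch threads into OpenaiHandler. The types and constants here are simplified stand-ins for the project's real dto structs and relayconstant values (assumptions for illustration only); just the shape of the switch mirrors the patched code.

package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical relay-mode constants; the real ones live in relay/constant.
const (
	RelayModeChatCompletions = iota
	RelayModeEmbeddings
)

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// upstreamResponse is a simplified TextResponseWithError: one struct wide
// enough to hold either a chat payload (Choices) or an embedding payload (Data).
type upstreamResponse struct {
	Object  string            `json:"object"`
	Model   string            `json:"model"`
	Data    []json.RawMessage `json:"data,omitempty"`
	Choices []json.RawMessage `json:"choices,omitempty"`
	Usage   Usage             `json:"usage"`
}

// rebuildBody re-serializes the upstream payload in the shape the client
// expects for the given relay mode, mirroring the switch in the patched handler.
func rebuildBody(r upstreamResponse, relayMode int) ([]byte, error) {
	switch relayMode {
	case RelayModeEmbeddings:
		// Embedding responses carry object/data/model/usage and no choices.
		return json.Marshal(map[string]any{
			"object": r.Object,
			"data":   r.Data,
			"model":  r.Model,
			"usage":  r.Usage,
		})
	default:
		// Chat responses carry model/choices/usage.
		return json.Marshal(map[string]any{
			"model":   r.Model,
			"choices": r.Choices,
			"usage":   r.Usage,
		})
	}
}

func main() {
	r := upstreamResponse{
		Object: "list",
		Model:  "text-embedding-3-small",
		Data:   []json.RawMessage{json.RawMessage(`{"embedding":[0.1,0.2]}`)},
		Usage:  Usage{PromptTokens: 5, TotalTokens: 5},
	}
	body, err := rebuildBody(r, RelayModeEmbeddings)
	if err != nil {
		panic(err)
	}
	// Prints an embedding-shaped body instead of a chat-shaped one.
	fmt.Println(string(body))
}

Before this change, OpenaiHandler always re-marshalled the upstream body as a dto.TextResponse, so embedding responses lost their "object", "data", and "model" fields on the way back to the client; dispatching on relayMode preserves the embedding shape while leaving the chat path (token recount plus sensitive-word scan) intact.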