diff --git a/controller/misc.go b/controller/misc.go index 04b2c5b..f15fa6a 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -33,6 +33,7 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ + "version": common.Version, "start_time": common.StartTime, "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, @@ -63,7 +64,6 @@ func GetStatus(c *gin.Context) { "default_collapse_sidebar": common.DefaultCollapseSidebar, "payment_enabled": common.PaymentEnabled, "mj_notify_enabled": constant.MjNotifyEnabled, - "version": common.Version, }, }) return diff --git a/dto/text_response.go b/dto/text_response.go index 63a344d..16deb0d 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -1,12 +1,21 @@ package dto type TextResponseWithError struct { - Choices []OpenAITextResponseChoice `json:"choices"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []OpenAITextResponseChoice `json:"choices"` + Data []OpenAIEmbeddingResponseItem `json:"data"` + Model string `json:"model"` Usage `json:"usage"` Error OpenAIError `json:"error"` } type TextResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` Choices []OpenAITextResponseChoice `json:"choices"` Usage `json:"usage"` } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 6ef2f30..69a97e3 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = 
openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } return } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index d9c52f8..9e2845d 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -77,7 +77,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 349d5d5..b8b7d8d 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d return nil, responseTextBuilder.String() } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { - var textResponseWithError dto.TextResponseWithError +func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { + var responseWithError dto.TextResponseWithError responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil @@ -134,64 +134,99 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } - err = json.Unmarshal(responseBody, 
&textResponseWithError) + err = json.Unmarshal(responseBody, &responseWithError) if err != nil { log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err) return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil } - if textResponseWithError.Error.Type != "" { + if responseWithError.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - Error: textResponseWithError.Error, + Error: responseWithError.Error, StatusCode: resp.StatusCode, }, nil, nil } - textResponse := &dto.TextResponse{ - Choices: textResponseWithError.Choices, - Usage: textResponseWithError.Usage, - } - checkSensitive := constant.ShouldCheckCompletionSensitive() sensitiveWords := make([]string, 0) triggerSensitive := false - if textResponse.Usage.TotalTokens == 0 || checkSensitive { - completionTokens := 0 - for _, choice := range textResponse.Choices { - stringContent := string(choice.Message.Content) - ctkm, _, _ := service.CountTokenText(stringContent, model, false) - completionTokens += ctkm - if checkSensitive { - sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) - if sensitive { - triggerSensitive = true - msg := choice.Message - msg.Content = common.StringToByteSlice(stringContent) - choice.Message = msg - sensitiveWords = append(sensitiveWords, words...) 
+ usage := &responseWithError.Usage + + //textResponse := &dto.TextResponse{ + // Choices: responseWithError.Choices, + // Usage: responseWithError.Usage, + //} + var doResponseBody []byte + + switch relayMode { + case relayconstant.RelayModeEmbeddings: + embeddingResponse := &dto.OpenAIEmbeddingResponse{ + Object: responseWithError.Object, + Data: responseWithError.Data, + Model: responseWithError.Model, + Usage: *usage, + } + doResponseBody, err = json.Marshal(embeddingResponse) + default: + if responseWithError.Usage.TotalTokens == 0 || checkSensitive { + completionTokens := 0 + for i, choice := range responseWithError.Choices { + stringContent := string(choice.Message.Content) + ctkm, _, _ := service.CountTokenText(stringContent, model, false) + completionTokens += ctkm + if checkSensitive { + sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) + if sensitive { + triggerSensitive = true + msg := choice.Message + msg.Content = common.StringToByteSlice(stringContent) + responseWithError.Choices[i].Message = msg + sensitiveWords = append(sensitiveWords, words...) 
+ } } } + responseWithError.Usage = dto.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } } - textResponse.Usage = dto.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + textResponse := &dto.TextResponse{ + Id: responseWithError.Id, + Created: responseWithError.Created, + Object: responseWithError.Object, + Choices: responseWithError.Choices, + Model: responseWithError.Model, + Usage: *usage, } + doResponseBody, err = json.Marshal(textResponse) } - if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive { - + if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled { + sensitiveWords = common.RemoveDuplicate(sensitiveWords) + return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", + strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), + usage, &dto.SensitiveResponse{ + SensitiveWords: sensitiveWords, + } } else { - responseBody, err = json.Marshal(textResponse) // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. // For example, Postman will report error, and we cannot check the response at all. 
+ // Copy headers for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) + // 删除任何现有的相同头部,以防止重复添加头部 + c.Writer.Header().Del(k) + for _, vv := range v { + c.Writer.Header().Add(k, vv) + } } + // reset content length + c.Writer.Header().Del("Content-Length") + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody))) c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { @@ -202,12 +237,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil } } - - if checkSensitive && triggerSensitive { - sensitiveWords = common.RemoveDuplicate(sensitiveWords) - return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{ - SensitiveWords: sensitiveWords, - } - } - return nil, &textResponse.Usage, nil + return nil, usage, nil } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 8d056ec..d04af1e 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } return } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index dded3c5..c7ea903 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ 
b/relay/channel/zhipu_4v/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } return } diff --git a/service/midjourney.go b/service/midjourney.go index 11ec5bd..4f43b52 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -185,7 +185,11 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) + auth := c.Request.Header.Get("Authorization") + if auth != "" { + auth = strings.TrimPrefix(auth, "Bearer ") + req.Header.Set("mj-api-secret", auth) + } defer cancel() resp, err := GetHttpClient().Do(req) if err != nil {