merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-03-22 18:09:13 +08:00
commit 71d60eeef7
8 changed files with 90 additions and 48 deletions

View File

@ -33,6 +33,7 @@ func GetStatus(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
"version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
@ -63,7 +64,6 @@ func GetStatus(c *gin.Context) {
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"payment_enabled": common.PaymentEnabled, "payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
"version": common.Version,
}, },
}) })
return return

View File

@ -1,12 +1,21 @@
package dto package dto
type TextResponseWithError struct { type TextResponseWithError struct {
Choices []OpenAITextResponseChoice `json:"choices"` Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"` Usage `json:"usage"`
Error OpenAIError `json:"error"` Error OpenAIError `json:"error"`
} }
type TextResponse struct { type TextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"` Usage `json:"usage"`
} }

View File

@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -77,7 +77,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String() return nil, responseTextBuilder.String()
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponseWithError dto.TextResponseWithError var responseWithError dto.TextResponseWithError
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@ -134,64 +134,99 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
} }
err = json.Unmarshal(responseBody, &textResponseWithError) err = json.Unmarshal(responseBody, &responseWithError)
if err != nil { if err != nil {
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err) log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
} }
if textResponseWithError.Error.Type != "" { if responseWithError.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: textResponseWithError.Error, Error: responseWithError.Error,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil, nil }, nil, nil
} }
textResponse := &dto.TextResponse{
Choices: textResponseWithError.Choices,
Usage: textResponseWithError.Usage,
}
checkSensitive := constant.ShouldCheckCompletionSensitive() checkSensitive := constant.ShouldCheckCompletionSensitive()
sensitiveWords := make([]string, 0) sensitiveWords := make([]string, 0)
triggerSensitive := false triggerSensitive := false
if textResponse.Usage.TotalTokens == 0 || checkSensitive { usage := &responseWithError.Usage
completionTokens := 0
for _, choice := range textResponse.Choices { //textResponse := &dto.TextResponse{
stringContent := string(choice.Message.Content) // Choices: responseWithError.Choices,
ctkm, _, _ := service.CountTokenText(stringContent, model, false) // Usage: responseWithError.Usage,
completionTokens += ctkm //}
if checkSensitive { var doResponseBody []byte
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive { switch relayMode {
triggerSensitive = true case relayconstant.RelayModeEmbeddings:
msg := choice.Message embeddingResponse := &dto.OpenAIEmbeddingResponse{
msg.Content = common.StringToByteSlice(stringContent) Object: responseWithError.Object,
choice.Message = msg Data: responseWithError.Data,
sensitiveWords = append(sensitiveWords, words...) Model: responseWithError.Model,
Usage: *usage,
}
doResponseBody, err = json.Marshal(embeddingResponse)
default:
if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0
for i, choice := range responseWithError.Choices {
stringContent := string(choice.Message.Content)
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive {
triggerSensitive = true
msg := choice.Message
msg.Content = common.StringToByteSlice(stringContent)
responseWithError.Choices[i].Message = msg
sensitiveWords = append(sensitiveWords, words...)
}
} }
} }
responseWithError.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
} }
textResponse.Usage = dto.Usage{ textResponse := &dto.TextResponse{
PromptTokens: promptTokens, Id: responseWithError.Id,
CompletionTokens: completionTokens, Created: responseWithError.Created,
TotalTokens: promptTokens + completionTokens, Object: responseWithError.Object,
Choices: responseWithError.Choices,
Model: responseWithError.Model,
Usage: *usage,
} }
doResponseBody, err = json.Marshal(textResponse)
} }
if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive { if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
} else { } else {
responseBody, err = json.Marshal(textResponse)
// Reset response body // Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all. // For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) // 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
} }
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
@ -202,12 +237,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
} }
} }
return nil, usage, nil
if checkSensitive && triggerSensitive {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
}
return nil, &textResponse.Usage, nil
} }

View File

@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -185,7 +185,12 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) auth := c.Request.Header.Get("Authorization")
if auth != "" {
auth = strings.TrimPrefix(auth, "Bearer ")
auth = strings.Split(auth, "-")[0]
req.Header.Set("mj-api-secret", auth)
}
defer cancel() defer cancel()
resp, err := GetHttpClient().Do(req) resp, err := GetHttpClient().Do(req)
if err != nil { if err != nil {