Merge branch 'Calcium-Ion:main' into main

This commit is contained in:
Maple Gao 2024-03-23 01:10:39 +08:00 committed by GitHub
commit a825699e9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 101 additions and 49 deletions

View File

@ -33,6 +33,7 @@ func GetStatus(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
"version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,

View File

@ -77,7 +77,7 @@ func RequestEpay(c *gin.Context) {
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log") returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10) tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
client := GetEpayClient() client := GetEpayClient()
if client == nil { if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})

View File

@ -1,9 +1,23 @@
package dto package dto
type TextResponse struct { type TextResponseWithError struct {
Choices []*OpenAITextResponseChoice `json:"choices"` Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type TextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"` Usage `json:"usage"`
Error *OpenAIError `json:"error,omitempty"`
} }
type OpenAITextResponseChoice struct { type OpenAITextResponseChoice struct {

View File

@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -78,7 +78,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String() return nil, responseTextBuilder.String()
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponse dto.TextResponse var responseWithError dto.TextResponseWithError
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@ -134,14 +134,14 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
} }
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &responseWithError)
if err != nil { if err != nil {
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
} }
log.Printf("textResponse: %+v", textResponse) if responseWithError.Error.Type != "" {
if textResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: *textResponse.Error, Error: responseWithError.Error,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil, nil }, nil, nil
} }
@ -150,43 +150,83 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
sensitiveWords := make([]string, 0) sensitiveWords := make([]string, 0)
triggerSensitive := false triggerSensitive := false
if textResponse.Usage.TotalTokens == 0 || checkSensitive { usage := &responseWithError.Usage
completionTokens := 0
for _, choice := range textResponse.Choices { //textResponse := &dto.TextResponse{
stringContent := string(choice.Message.Content) // Choices: responseWithError.Choices,
ctkm, _, _ := service.CountTokenText(stringContent, model, false) // Usage: responseWithError.Usage,
completionTokens += ctkm //}
if checkSensitive { var doResponseBody []byte
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive { switch relayMode {
triggerSensitive = true case relayconstant.RelayModeEmbeddings:
msg := choice.Message embeddingResponse := &dto.OpenAIEmbeddingResponse{
msg.Content = common.StringToByteSlice(stringContent) Object: responseWithError.Object,
choice.Message = msg Data: responseWithError.Data,
sensitiveWords = append(sensitiveWords, words...) Model: responseWithError.Model,
Usage: *usage,
}
doResponseBody, err = json.Marshal(embeddingResponse)
default:
if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0
for i, choice := range responseWithError.Choices {
stringContent := string(choice.Message.Content)
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive {
triggerSensitive = true
msg := choice.Message
msg.Content = common.StringToByteSlice(stringContent)
responseWithError.Choices[i].Message = msg
sensitiveWords = append(sensitiveWords, words...)
}
} }
} }
responseWithError.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
} }
textResponse.Usage = dto.Usage{ textResponse := &dto.TextResponse{
PromptTokens: promptTokens, Id: responseWithError.Id,
CompletionTokens: completionTokens, Created: responseWithError.Created,
TotalTokens: promptTokens + completionTokens, Object: responseWithError.Object,
Choices: responseWithError.Choices,
Model: responseWithError.Model,
Usage: *usage,
} }
doResponseBody, err = json.Marshal(textResponse)
} }
if constant.StopOnSensitiveEnabled { if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
} else { } else {
responseBody, err = json.Marshal(textResponse)
// Reset response body // Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all. // For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) // 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
} }
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
@ -197,12 +237,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
} }
} }
return nil, usage, nil
if checkSensitive && triggerSensitive {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
}
return nil, &textResponse.Usage, nil
} }

View File

@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} }
return return
} }

View File

@ -35,12 +35,12 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
if err != nil { if err != nil {
return return
} }
var textResponse dto.TextResponse var textResponse dto.TextResponseWithError
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &textResponse)
if err != nil { if err != nil {
return return
} }
OpenAIErrorWithStatusCode.Error = *textResponse.Error OpenAIErrorWithStatusCode.Error = textResponse.Error
return return
} }

View File

@ -185,7 +185,11 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) auth := c.Request.Header.Get("Authorization")
if auth != "" {
auth = strings.TrimPrefix(auth, "Bearer ")
req.Header.Set("mj-api-secret", auth)
}
defer cancel() defer cancel()
resp, err := GetHttpClient().Do(req) resp, err := GetHttpClient().Do(req)
if err != nil { if err != nil {

View File

@ -40,7 +40,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
for _, hit := range hits { for _, hit := range hits {
pos := hit.Pos pos := hit.Pos
word := string(hit.Word) word := string(hit.Word)
text = text[:pos] + "*###*" + text[pos+len(word):] text = text[:pos] + "**###**" + text[pos+len(word):]
words = append(words, word) words = append(words, word)
} }
return true, words, text return true, words, text