diff --git a/common/constants.go b/common/constants.go index a44a2af..0e4192a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -77,6 +77,7 @@ var LogConsumeEnabled = true var SMTPServer = "" var SMTPPort = 587 +var SMTPSSLEnabled = false var SMTPAccount = "" var SMTPFrom = "" var SMTPToken = "" diff --git a/common/email.go b/common/email.go index 5d3ef0d..13345d8 100644 --- a/common/email.go +++ b/common/email.go @@ -24,7 +24,7 @@ func SendEmail(subject string, receiver string, content string) error { addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") var err error - if SMTPPort == 465 { + if SMTPPort == 465 || SMTPSSLEnabled { tlsConfig := &tls.Config{ InsecureSkipVerify: true, ServerName: SMTPServer, diff --git a/constant/sensitive.go b/constant/sensitive.go index 8d8e15f..5297560 100644 --- a/constant/sensitive.go +++ b/constant/sensitive.go @@ -4,7 +4,8 @@ import "strings" var CheckSensitiveEnabled = true var CheckSensitiveOnPromptEnabled = true -var CheckSensitiveOnCompletionEnabled = true + +//var CheckSensitiveOnCompletionEnabled = true // StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词 var StopOnSensitiveEnabled = true @@ -37,6 +38,6 @@ func ShouldCheckPromptSensitive() bool { return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled } -func ShouldCheckCompletionSensitive() bool { - return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled -} +//func ShouldCheckCompletionSensitive() bool { +// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled +//} diff --git a/controller/channel-test.go b/controller/channel-test.go index 6d64a24..a4dcfe9 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr err := relaycommon.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } - usage, respErr, _ := adaptor.DoResponse(c, resp, meta) + usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error } diff --git a/controller/midjourney.go b/controller/midjourney.go index 41db4bf..b5b832b 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -147,7 +147,7 @@ func UpdateMidjourneyTaskBulk() { task.Buttons = string(buttonStr) } - if task.Progress != "100%" && responseItem.FailReason != "" { + if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" err = model.CacheUpdateUserQuota(task.UserId) diff --git a/dto/text_response.go b/dto/text_response.go index 16deb0d..98275fe 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -11,6 +11,12 @@ type TextResponseWithError struct { Error OpenAIError `json:"error"` } +type SimpleResponse struct { + Usage `json:"usage"` + Error OpenAIError `json:"error"` + Choices []OpenAITextResponseChoice `json:"choices"` +} + type TextResponse struct { Id string `json:"id"` Object string `json:"object"` diff --git a/middleware/distributor.go b/middleware/distributor.go index a5e40b0..10696a9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -44,7 +44,7 @@ func Distribute() func(c *gin.Context) { // Select a channel for the user var modelRequest ModelRequest var err error - if strings.HasPrefix(c.Request.URL.Path, "/mj") { + if strings.Contains(c.Request.URL.Path, "/mj/") { relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || diff --git a/model/option.go b/model/option.go index c260680..2a32c72 100644 --- a/model/option.go +++ b/model/option.go @@ -52,6 +52,7 @@ func InitOptionMap() { common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) common.OptionMap["SMTPAccount"] = "" common.OptionMap["SMTPToken"] = "" + common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) common.OptionMap["Notice"] = "" common.OptionMap["About"] = "" common.OptionMap["HomePageContent"] = "" @@ -97,7 +98,7 @@ func InitOptionMap() { common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) - common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) + //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength) @@ -204,10 +205,12 @@ func updateOptionMap(key string, value string) (err error) { constant.CheckSensitiveEnabled = boolValue case "CheckSensitiveOnPromptEnabled": constant.CheckSensitiveOnPromptEnabled = boolValue - case "CheckSensitiveOnCompletionEnabled": - constant.CheckSensitiveOnCompletionEnabled = boolValue + //case "CheckSensitiveOnCompletionEnabled": + // constant.CheckSensitiveOnCompletionEnabled = boolValue case "StopOnSensitiveEnabled": constant.StopOnSensitiveEnabled = boolValue + case "SMTPSSLEnabled": + common.SMTPSSLEnabled = boolValue } } switch key { diff --git a/model/user.go b/model/user.go index e19dd23..22258d9 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "strconv" "strings" "time" @@ -75,8 +76,26 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { return users, err } -func SearchUsers(keyword string) (users []*User, err error) { - err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error +func SearchUsers(keyword string) ([]*User, error) { + var users []*User + var err error + + // 尝试将关键字转换为整数ID + keywordInt, err := strconv.Atoi(keyword) + if err == nil { + // 如果转换成功,按照ID搜索用户 + err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error + if err != nil || len(users) > 0 { + // 如果依据ID找到用户或者发生错误,返回结果或错误 + return users, err + } + } + + // 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索 + err = DB.Unscoped().Omit("password"). + Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%"). + Find(&users).Error + return users, err } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 935a0f0..d3886d5 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -15,7 +15,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 155d5a0..bfe83db 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = aliStreamHandler(c, resp) } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d8c96aa..d2571dc 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = baiduStreamHandler(c, resp) } else { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index ee5a4c0..45efd01 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) } else { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index f1a3c39..4de8dc0 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/service" "strings" @@ -317,7 +316,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT }, nil } fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 5f78829..a275175 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = geminiChatStreamHandler(c, resp) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 31badd8..4a10a73 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -257,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index f66d9a9..4e1fd33 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -49,16 +49,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { - err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } } return diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index fa5f818..828ddea 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -45,19 +45,19 @@ func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbedding } } -func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { +func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var ollamaEmbeddingResponse OllamaEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data = append(data, dto.OpenAIEmbeddingResponseItem{ @@ -77,7 +77,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in } doResponseBody, err := json.Marshal(embeddingResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. @@ -98,11 +98,11 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - return nil, usage, nil + return nil, usage } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 7dc591e..cab6a64 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -69,13 +69,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 39127de..fe5cd48 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,14 +4,10 @@ import ( "bufio" "bytes" "encoding/json" - "errors" - "fmt" "github.com/gin-gonic/gin" "io" - "log" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" "one-api/service" @@ -21,7 +17,7 @@ import ( ) func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { - checkSensitive := constant.ShouldCheckCompletionSensitive() + //checkSensitive := constant.ShouldCheckCompletionSensitive() var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d if data[:6] != "data: " && data[:6] != "[DONE]" { continue } - sensitive := false - if checkSensitive { - // check sensitive - sensitive, _, data = service.SensitiveWordReplace(data, false) - } dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { streamItems = append(streamItems, data) } - if sensitive && constant.StopOnSensitiveEnabled { - dataChan <- "data: [DONE]" - break - } } streamResp := "[" + strings.Join(streamItems, ",") + "]" switch relayMode { @@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d return nil, responseTextBuilder.String() } -func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) { - var responseWithError dto.TextResponseWithError +func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var simpleResponse dto.SimpleResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - err = json.Unmarshal(responseBody, &responseWithError) + err = json.Unmarshal(responseBody, &simpleResponse) if err != nil { - log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err) - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - if responseWithError.Error.Type != "" { + if simpleResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - Error: responseWithError.Error, + Error: simpleResponse.Error, StatusCode: resp.StatusCode, - }, nil, nil + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - checkSensitive := constant.ShouldCheckCompletionSensitive() - sensitiveWords := make([]string, 0) - triggerSensitive := false - - usage := &responseWithError.Usage - - //textResponse := &dto.TextResponse{ - // Choices: responseWithError.Choices, - // Usage: responseWithError.Usage, - //} - var doResponseBody []byte - - switch relayMode { - case relayconstant.RelayModeEmbeddings: - embeddingResponse := &dto.OpenAIEmbeddingResponse{ - Object: responseWithError.Object, - Data: responseWithError.Data, - Model: responseWithError.Model, - Usage: *usage, + if simpleResponse.Usage.TotalTokens == 0 { + completionTokens := 0 + for _, choice := range simpleResponse.Choices { + ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false) + completionTokens += ctkm } - doResponseBody, err = json.Marshal(embeddingResponse) - default: - if responseWithError.Usage.TotalTokens == 0 || checkSensitive { - completionTokens := 0 - for i, choice := range responseWithError.Choices { - stringContent := string(choice.Message.Content) - ctkm, _, _ := service.CountTokenText(stringContent, model, false) - completionTokens += ctkm - if checkSensitive { - sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false) - if sensitive { - triggerSensitive = true - msg := choice.Message - msg.Content = common.StringToByteSlice(stringContent) - responseWithError.Choices[i].Message = msg - sensitiveWords = append(sensitiveWords, words...) - } - } - } - responseWithError.Usage = dto.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - } - textResponse := &dto.TextResponse{ - Id: responseWithError.Id, - Created: responseWithError.Created, - Object: responseWithError.Object, - Choices: responseWithError.Choices, - Model: responseWithError.Model, - Usage: *usage, - } - doResponseBody, err = json.Marshal(textResponse) - } - - if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled { - sensitiveWords = common.RemoveDuplicate(sensitiveWords) - return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", - strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), - usage, &dto.SensitiveResponse{ - SensitiveWords: sensitiveWords, - } - } else { - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - // Copy headers - for k, v := range resp.Header { - // 删除任何现有的相同头部,以防止重复添加头部 - c.Writer.Header().Del(k) - for _, vv := range v { - c.Writer.Header().Add(k, vv) - } - } - // reset content length - c.Writer.Header().Del("Content-Length") - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody))) - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil + simpleResponse.Usage = dto.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, } } - return nil, usage, nil + return nil, &simpleResponse.Usage } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 2129ee3..4f59a44 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 4028269..3a7d4fa 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -157,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive()) + completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index d04af1e..24765ff 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index fa3da0c..470ec14 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 9ab2818..79a4b12 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { splits := strings.Split(info.ApiKey, "|") if len(splits) != 3 { - return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil + return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) } if a.request == nil { - return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil + return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) } if info.IsStream { err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 69c45a8..6f2d186 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = zhipuStreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index c7ea903..1b8866b 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 1790c57..2e94bc0 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -56,29 +56,29 @@ func Path2RelayMode(path string) int { func Path2RelayModeMidjourney(path string) int { relayMode := RelayModeUnknown - if strings.HasPrefix(path, "/mj/submit/action") { + if strings.HasSuffix(path, "/mj/submit/action") { // midjourney plus relayMode = RelayModeMidjourneyAction - } else if strings.HasPrefix(path, "/mj/submit/modal") { + } else if strings.HasSuffix(path, "/mj/submit/modal") { // midjourney plus relayMode = RelayModeMidjourneyModal - } else if strings.HasPrefix(path, "/mj/submit/shorten") { + } else if strings.HasSuffix(path, "/mj/submit/shorten") { // midjourney plus relayMode = RelayModeMidjourneyShorten - } else if strings.HasPrefix(path, "/mj/insight-face/swap") { + } else if strings.HasSuffix(path, "/mj/insight-face/swap") { // midjourney plus relayMode = RelayModeSwapFace - } else if strings.HasPrefix(path, "/mj/submit/imagine") { + } else if strings.HasSuffix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine - } else if strings.HasPrefix(path, "/mj/submit/blend") { + } else if strings.HasSuffix(path, "/mj/submit/blend") { relayMode = RelayModeMidjourneyBlend - } else if strings.HasPrefix(path, "/mj/submit/describe") { + } else if strings.HasSuffix(path, "/mj/submit/describe") { relayMode = RelayModeMidjourneyDescribe - } else if strings.HasPrefix(path, "/mj/notify") { + } else if strings.HasSuffix(path, "/mj/notify") { relayMode = RelayModeMidjourneyNotify - } else if strings.HasPrefix(path, "/mj/submit/change") { + } else if strings.HasSuffix(path, "/mj/submit/change") { relayMode = RelayModeMidjourneyChange - } else if strings.HasPrefix(path, "/mj/submit/simple-change") { + } else if strings.HasSuffix(path, "/mj/submit/simple-change") { relayMode = RelayModeMidjourneyChange } else if strings.HasSuffix(path, "/fetch") { relayMode = RelayModeMidjourneyTaskFetch diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 1c0f868..d4458ce 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if strings.HasPrefix(audioRequest.Model, "tts-1") { quota = promptTokens } else { - quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive()) + quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 3cd42cb..7b3a4e2 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -180,7 +180,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Description: "quota_not_enough", } } - requestURL := c.Request.URL.String() + requestURL := getMjRequestPath(c.Request.URL.String()) baseURL := c.GetString("base_url") fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) @@ -260,7 +260,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - requestURL := c.Request.URL.String() + requestURL := getMjRequestPath(c.Request.URL.String()) fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) if err != nil { @@ -440,7 +440,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } //baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() + requestURL := getMjRequestPath(c.Request.URL.String()) baseURL := c.GetString("base_url") @@ -605,3 +605,15 @@ type taskChangeParams struct { Action string Index int } + +func getMjRequestPath(path string) string { + requestURL := path + if strings.Contains(requestURL, "/mj-") { + urls := strings.Split(requestURL, "/mj/") + if len(urls) < 2 { + return requestURL + } + requestURL = "/mj/" + urls[1] + } + return requestURL +} diff --git a/relay/relay-text.go b/relay/relay-text.go index ec64e04..ff653ff 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -165,21 +165,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode) } - usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { - if sensitiveResp == nil { // 如果没有敏感词检查结果 - returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return openaiErr - } else { - // 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额 - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp) - if constant.StopOnSensitiveEnabled { // 是否直接返回错误 - return openaiErr - } - return nil - } + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + return openaiErr } - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil) + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) return nil } @@ -258,7 +249,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, sensitiveResp *dto.SensitiveResponse) { + modelPrice float64) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -293,9 +284,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe logContent += fmt.Sprintf("(可能是上游超时)") common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) } else { - if sensitiveResp != nil { - logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - } + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { diff --git a/router/relay-router.go b/router/relay-router.go index 4addee0..2d8e7b3 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -43,7 +43,16 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) relayV1Router.POST("/moderations", controller.Relay) } + relayMjRouter := router.Group("/mj") + registerMjRouterGroup(relayMjRouter) + + relayMjModeRouter := router.Group("/:mode/mj") + registerMjRouterGroup(relayMjModeRouter) + //relayMjRouter.Use() +} + +func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { @@ -61,5 +70,4 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) } - //relayMjRouter.Use() } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 3019906..f42fe57 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -330,21 +330,21 @@ const OperationSetting = () => { name='CheckSensitiveOnPromptEnabled' onChange={handleInputChange} /> - - - - + {/**/} + {/**/} + {/* */} + {/**/} {/**/} {/* { RegisterEnabled: '', UserSelfDeletionEnabled: false, EmailDomainRestrictionEnabled: '', + SMTPSSLEnabled: '', EmailDomainWhitelist: [], // telegram login TelegramOAuthEnabled: '', @@ -104,6 +105,7 @@ const SystemSetting = () => { case 'TelegramOAuthEnabled': case 'TurnstileCheckEnabled': case 'EmailDomainRestrictionEnabled': + case 'SMTPSSLEnabled': case 'RegisterEnabled': case 'UserSelfDeletionEnabled': case 'PaymentEnabled': @@ -139,7 +141,7 @@ const SystemSetting = () => { } if ( name === 'Notice' || - name.startsWith('SMTP') || + (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') || name === 'ServerAddress' || name === 'StripeApiSecret' || name === 'StripeWebhookSecret' || @@ -652,6 +654,14 @@ const SystemSetting = () => { placeholder='敏感信息不会发送到前端显示' /> + + + 保存 SMTP 设置