From 5c39f54040d0bd4750de8f5139bf911dbeb8e0be Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 28 Mar 2024 12:18:11 +0800
Subject: [PATCH 1/8] feat: able to set smtp ssl
---
common/constants.go | 1 +
common/email.go | 2 +-
model/option.go | 3 +++
web/src/components/SystemSetting.js | 12 +++++++++++-
4 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/common/constants.go b/common/constants.go
index 4000e08..85ecadd 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -75,6 +75,7 @@ var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
+var SMTPSSLEnabled = false
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
diff --git a/common/email.go b/common/email.go
index 5d3ef0d..13345d8 100644
--- a/common/email.go
+++ b/common/email.go
@@ -24,7 +24,7 @@ func SendEmail(subject string, receiver string, content string) error {
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
- if SMTPPort == 465 {
+ if SMTPPort == 465 || SMTPSSLEnabled {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,
diff --git a/model/option.go b/model/option.go
index 46e41da..2ccfe03 100644
--- a/model/option.go
+++ b/model/option.go
@@ -50,6 +50,7 @@ func InitOptionMap() {
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
+ common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
@@ -199,6 +200,8 @@ func updateOptionMap(key string, value string) (err error) {
constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue
+ case "SMTPSSLEnabled":
+ common.SMTPSSLEnabled = boolValue
}
}
switch key {
diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js
index c305d2c..8d13ef1 100644
--- a/web/src/components/SystemSetting.js
+++ b/web/src/components/SystemSetting.js
@@ -42,6 +42,7 @@ const SystemSetting = () => {
TurnstileSecretKey: '',
RegisterEnabled: '',
EmailDomainRestrictionEnabled: '',
+ SMTPSSLEnabled: '',
EmailDomainWhitelist: [],
// telegram login
TelegramOAuthEnabled: '',
@@ -98,6 +99,7 @@ const SystemSetting = () => {
case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
+ case 'SMTPSSLEnabled':
case 'RegisterEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
@@ -134,7 +136,7 @@ const SystemSetting = () => {
}
if (
name === 'Notice' ||
- name.startsWith('SMTP') ||
+ (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' ||
name === 'EpayId' ||
name === 'EpayKey' ||
@@ -570,6 +572,14 @@ const SystemSetting = () => {
placeholder='敏感信息不会发送到前端显示'
/>
+
+
+
保存 SMTP 设置
From 49df4b6eed1d2888769e34065aa6d3c593ad6200 Mon Sep 17 00:00:00 2001
From: Xiangyuan Liu
Date: Fri, 29 Mar 2024 16:48:50 +0800
Subject: [PATCH 2/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20/mj-{mode}=20?=
=?UTF-8?q?=E8=B7=AF=E5=BE=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
middleware/distributor.go | 2 +-
router/relay-router.go | 10 +++++++++-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index a5e40b0..10696a9 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -44,7 +44,7 @@ func Distribute() func(c *gin.Context) {
// Select a channel for the user
var modelRequest ModelRequest
var err error
- if strings.HasPrefix(c.Request.URL.Path, "/mj") {
+ if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
diff --git a/router/relay-router.go b/router/relay-router.go
index 4addee0..2d8e7b3 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -43,7 +43,16 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
}
+
relayMjRouter := router.Group("/mj")
+ registerMjRouterGroup(relayMjRouter)
+
+ relayMjModeRouter := router.Group("/:mode/mj")
+ registerMjRouterGroup(relayMjModeRouter)
+ //relayMjRouter.Use()
+}
+
+func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
@@ -61,5 +70,4 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
}
- //relayMjRouter.Use()
}
From 2e595bdafb2705fb9dec2db201b8d13366e4cf50 Mon Sep 17 00:00:00 2001
From: Xiangyuan Liu
Date: Fri, 29 Mar 2024 16:58:19 +0800
Subject: [PATCH 3/8] =?UTF-8?q?fix:=20=E6=94=AF=E6=8C=81=20/mj-{mode}=20?=
=?UTF-8?q?=E8=B7=AF=E5=BE=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
relay/relay-mj.go | 18 +++++++++++++++---
1 file changed, 15 insertions(+), 3 deletions(-)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 3cd42cb..7b3a4e2 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -180,7 +180,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
Description: "quota_not_enough",
}
}
- requestURL := c.Request.URL.String()
+ requestURL := getMjRequestPath(c.Request.URL.String())
baseURL := c.GetString("base_url")
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
@@ -260,7 +260,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- requestURL := c.Request.URL.String()
+ requestURL := getMjRequestPath(c.Request.URL.String())
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
@@ -440,7 +440,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
//baseURL := common.ChannelBaseURLs[channelType]
- requestURL := c.Request.URL.String()
+ requestURL := getMjRequestPath(c.Request.URL.String())
baseURL := c.GetString("base_url")
@@ -605,3 +605,15 @@ type taskChangeParams struct {
Action string
Index int
}
+
+func getMjRequestPath(path string) string {
+ requestURL := path
+ if strings.Contains(requestURL, "/mj-") {
+ urls := strings.Split(requestURL, "/mj/")
+ if len(urls) < 2 {
+ return requestURL
+ }
+ requestURL = "/mj/" + urls[1]
+ }
+ return requestURL
+}
From 3065bf92ae4ccff27c9b164e0a3816f2dff2a974 Mon Sep 17 00:00:00 2001
From: Xiangyuan Liu
Date: Fri, 29 Mar 2024 17:36:44 +0800
Subject: [PATCH 4/8] =?UTF-8?q?fix:=20=E6=94=AF=E6=8C=81=20/mj-{mode}=20?=
=?UTF-8?q?=E8=B7=AF=E5=BE=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
relay/constant/relay_mode.go | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 1790c57..2e94bc0 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -56,29 +56,29 @@ func Path2RelayMode(path string) int {
func Path2RelayModeMidjourney(path string) int {
relayMode := RelayModeUnknown
- if strings.HasPrefix(path, "/mj/submit/action") {
+ if strings.HasSuffix(path, "/mj/submit/action") {
// midjourney plus
relayMode = RelayModeMidjourneyAction
- } else if strings.HasPrefix(path, "/mj/submit/modal") {
+ } else if strings.HasSuffix(path, "/mj/submit/modal") {
// midjourney plus
relayMode = RelayModeMidjourneyModal
- } else if strings.HasPrefix(path, "/mj/submit/shorten") {
+ } else if strings.HasSuffix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = RelayModeMidjourneyShorten
- } else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+ } else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
- } else if strings.HasPrefix(path, "/mj/submit/imagine") {
+ } else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
- } else if strings.HasPrefix(path, "/mj/submit/blend") {
+ } else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
- } else if strings.HasPrefix(path, "/mj/submit/describe") {
+ } else if strings.HasSuffix(path, "/mj/submit/describe") {
relayMode = RelayModeMidjourneyDescribe
- } else if strings.HasPrefix(path, "/mj/notify") {
+ } else if strings.HasSuffix(path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
- } else if strings.HasPrefix(path, "/mj/submit/change") {
+ } else if strings.HasSuffix(path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
- } else if strings.HasPrefix(path, "/mj/submit/simple-change") {
+ } else if strings.HasSuffix(path, "/mj/submit/simple-change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
From 44a8ade4bac9a35af0bd683ec9f0e594c6bca392 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Fri, 29 Mar 2024 22:20:14 +0800
Subject: [PATCH 5/8] fix: remove sensitive check on completion (close #157)
---
constant/sensitive.go | 9 +-
controller/channel-test.go | 2 +-
dto/text_response.go | 6 +
model/option.go | 6 +-
relay/channel/adapter.go | 2 +-
relay/channel/ali/adaptor.go | 2 +-
relay/channel/baidu/adaptor.go | 2 +-
relay/channel/claude/adaptor.go | 2 +-
relay/channel/claude/relay-claude.go | 3 +-
relay/channel/gemini/adaptor.go | 2 +-
relay/channel/gemini/relay-gemini.go | 3 +-
relay/channel/ollama/adaptor.go | 6 +-
relay/channel/ollama/relay-ollama.go | 16 +--
relay/channel/openai/adaptor.go | 4 +-
relay/channel/openai/relay-openai.go | 151 +++++++------------------
relay/channel/palm/adaptor.go | 2 +-
relay/channel/palm/relay-palm.go | 3 +-
relay/channel/perplexity/adaptor.go | 4 +-
relay/channel/tencent/adaptor.go | 2 +-
relay/channel/xunfei/adaptor.go | 6 +-
relay/channel/zhipu/adaptor.go | 2 +-
relay/channel/zhipu_4v/adaptor.go | 4 +-
relay/relay-audio.go | 2 +-
relay/relay-text.go | 25 ++--
web/src/components/OperationSetting.js | 28 ++---
25 files changed, 107 insertions(+), 187 deletions(-)
diff --git a/constant/sensitive.go b/constant/sensitive.go
index 8d8e15f..5297560 100644
--- a/constant/sensitive.go
+++ b/constant/sensitive.go
@@ -4,7 +4,8 @@ import "strings"
var CheckSensitiveEnabled = true
var CheckSensitiveOnPromptEnabled = true
-var CheckSensitiveOnCompletionEnabled = true
+
+//var CheckSensitiveOnCompletionEnabled = true
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
var StopOnSensitiveEnabled = true
@@ -37,6 +38,6 @@ func ShouldCheckPromptSensitive() bool {
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
}
-func ShouldCheckCompletionSensitive() bool {
- return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
-}
+//func ShouldCheckCompletionSensitive() bool {
+// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
+//}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 6d64a24..a4dcfe9 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
- usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
+ usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
diff --git a/dto/text_response.go b/dto/text_response.go
index 16deb0d..98275fe 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -11,6 +11,12 @@ type TextResponseWithError struct {
Error OpenAIError `json:"error"`
}
+type SimpleResponse struct {
+ Usage `json:"usage"`
+ Error OpenAIError `json:"error"`
+ Choices []OpenAITextResponseChoice `json:"choices"`
+}
+
type TextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
diff --git a/model/option.go b/model/option.go
index 2ccfe03..2bfa22a 100644
--- a/model/option.go
+++ b/model/option.go
@@ -93,7 +93,7 @@ func InitOptionMap() {
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
- common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
+ //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
@@ -196,8 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue
- case "CheckSensitiveOnCompletionEnabled":
- constant.CheckSensitiveOnCompletionEnabled = boolValue
+ //case "CheckSensitiveOnCompletionEnabled":
+ // constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index 935a0f0..d3886d5 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -15,7 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
- DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
+ DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 155d5a0..bfe83db 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index d8c96aa..d2571dc 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
} else {
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index ee5a4c0..45efd01 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index f1a3c39..4de8dc0 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -8,7 +8,6 @@ import (
"io"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
"one-api/service"
"strings"
@@ -317,7 +316,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
- completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 5f78829..a275175 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 31badd8..4a10a73 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -7,7 +7,6 @@ import (
"io"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -257,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index f66d9a9..4e1fd33 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -49,16 +49,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
- err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+ err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} else {
- err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
}
return
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index fa5f818..828ddea 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -45,19 +45,19 @@ func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbedding
}
}
-func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
@@ -77,7 +77,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
}
doResponseBody, err := json.Marshal(embeddingResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -98,11 +98,11 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- return nil, usage, nil
+ return nil, usage
}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 7dc591e..cab6a64 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -69,13 +69,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+ err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 39127de..fe5cd48 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -4,14 +4,10 @@ import (
"bufio"
"bytes"
"encoding/json"
- "errors"
- "fmt"
"github.com/gin-gonic/gin"
"io"
- "log"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/service"
@@ -21,7 +17,7 @@ import (
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
- checkSensitive := constant.ShouldCheckCompletionSensitive()
+ //checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
- sensitive := false
- if checkSensitive {
- // check sensitive
- sensitive, _, data = service.SensitiveWordReplace(data, false)
- }
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
}
- if sensitive && constant.StopOnSensitiveEnabled {
- dataChan <- "data: [DONE]"
- break
- }
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
@@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String()
}
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
- var responseWithError dto.TextResponseWithError
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+ var simpleResponse dto.SimpleResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- err = json.Unmarshal(responseBody, &responseWithError)
+ err = json.Unmarshal(responseBody, &simpleResponse)
if err != nil {
- log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
- if responseWithError.Error.Type != "" {
+ if simpleResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
- Error: responseWithError.Error,
+ Error: simpleResponse.Error,
StatusCode: resp.StatusCode,
- }, nil, nil
+ }, nil
+ }
+ // Reset response body
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ // We shouldn't set the header before we parse the response body, because the parse part may fail.
+ // And then we will have to send an error response, but in this case, the header has already been set.
+ // So the httpClient will be confused by the response.
+ // For example, Postman will report error, and we cannot check the response at all.
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- checkSensitive := constant.ShouldCheckCompletionSensitive()
- sensitiveWords := make([]string, 0)
- triggerSensitive := false
-
- usage := &responseWithError.Usage
-
- //textResponse := &dto.TextResponse{
- // Choices: responseWithError.Choices,
- // Usage: responseWithError.Usage,
- //}
- var doResponseBody []byte
-
- switch relayMode {
- case relayconstant.RelayModeEmbeddings:
- embeddingResponse := &dto.OpenAIEmbeddingResponse{
- Object: responseWithError.Object,
- Data: responseWithError.Data,
- Model: responseWithError.Model,
- Usage: *usage,
+ if simpleResponse.Usage.TotalTokens == 0 {
+ completionTokens := 0
+ for _, choice := range simpleResponse.Choices {
+ ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
+ completionTokens += ctkm
}
- doResponseBody, err = json.Marshal(embeddingResponse)
- default:
- if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
- completionTokens := 0
- for i, choice := range responseWithError.Choices {
- stringContent := string(choice.Message.Content)
- ctkm, _, _ := service.CountTokenText(stringContent, model, false)
- completionTokens += ctkm
- if checkSensitive {
- sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
- if sensitive {
- triggerSensitive = true
- msg := choice.Message
- msg.Content = common.StringToByteSlice(stringContent)
- responseWithError.Choices[i].Message = msg
- sensitiveWords = append(sensitiveWords, words...)
- }
- }
- }
- responseWithError.Usage = dto.Usage{
- PromptTokens: promptTokens,
- CompletionTokens: completionTokens,
- TotalTokens: promptTokens + completionTokens,
- }
- }
- textResponse := &dto.TextResponse{
- Id: responseWithError.Id,
- Created: responseWithError.Created,
- Object: responseWithError.Object,
- Choices: responseWithError.Choices,
- Model: responseWithError.Model,
- Usage: *usage,
- }
- doResponseBody, err = json.Marshal(textResponse)
- }
-
- if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
- sensitiveWords = common.RemoveDuplicate(sensitiveWords)
- return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
- strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
- usage, &dto.SensitiveResponse{
- SensitiveWords: sensitiveWords,
- }
- } else {
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- // Copy headers
- for k, v := range resp.Header {
- // 删除任何现有的相同头部,以防止重复添加头部
- c.Writer.Header().Del(k)
- for _, vv := range v {
- c.Writer.Header().Add(k, vv)
- }
- }
- // reset content length
- c.Writer.Header().Del("Content-Length")
- c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+ simpleResponse.Usage = dto.Usage{
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: promptTokens + completionTokens,
}
}
- return nil, usage, nil
+ return nil, &simpleResponse.Usage
}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 2129ee3..4f59a44 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 4028269..3a7d4fa 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -7,7 +7,6 @@ import (
"io"
"net/http"
"one-api/common"
- "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -157,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+ completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index d04af1e..24765ff 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index fa3da0c..470ec14 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 9ab2818..79a4b12 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return dummyResp, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
- return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
+ return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
if a.request == nil {
- return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
+ return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
}
if info.IsStream {
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 69c45a8..6f2d186 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
} else {
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index c7ea903..1b8866b 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index 1c0f868..d4458ce 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = promptTokens
} else {
- quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+ quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
diff --git a/relay/relay-text.go b/relay/relay-text.go
index ec64e04..ff653ff 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -165,21 +165,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
}
- usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
+ usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
- if sensitiveResp == nil { // 如果没有敏感词检查结果
- returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
- return openaiErr
- } else {
- // 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
- postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
- if constant.StopOnSensitiveEnabled { // 是否直接返回错误
- return openaiErr
- }
- return nil
- }
+ returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+ return openaiErr
}
- postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
+ postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
return nil
}
@@ -258,7 +249,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
+ modelPrice float64) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
@@ -293,9 +284,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
} else {
- if sensitiveResp != nil {
- logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
- }
+ //if sensitiveResp != nil {
+ // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+ //}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index 3019906..f42fe57 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -330,21 +330,21 @@ const OperationSetting = () => {
name='CheckSensitiveOnPromptEnabled'
onChange={handleInputChange}
/>
-
-
-
-
+ {/**/}
+ {/**/}
+ {/* */}
+ {/**/}
{/**/}
{/*
Date: Fri, 29 Mar 2024 22:48:37 +0800
Subject: [PATCH 6/8] fix: SearchUsers (close #160)
---
model/user.go | 23 +++++++++++++++++++++--
1 file changed, 21 insertions(+), 2 deletions(-)
diff --git a/model/user.go b/model/user.go
index 00294b2..3d442b3 100644
--- a/model/user.go
+++ b/model/user.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "strconv"
"strings"
"time"
@@ -72,8 +73,26 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
return users, err
}
-func SearchUsers(keyword string) (users []*User, err error) {
- err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
+func SearchUsers(keyword string) ([]*User, error) {
+ var users []*User
+ var err error
+
+ // 尝试将关键字转换为整数ID
+ keywordInt, err := strconv.Atoi(keyword)
+ if err == nil {
+ // 如果转换成功,按照ID搜索用户
+ err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
+ if err != nil || len(users) > 0 {
+ // 如果依据ID找到用户或者发生错误,返回结果或错误
+ return users, err
+ }
+ }
+
+ // 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索
+ err = DB.Unscoped().Omit("password").
+ Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
+ Find(&users).Error
+
return users, err
}
From 706449dede9bda59f617dd2dac971f7f59a345e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=99=E7=94=9F=E4=B8=80=E4=B8=AA=E7=99=BD=E6=81=A9?=
<591698275@qq.com>
Date: Sat, 30 Mar 2024 13:21:05 +0800
Subject: [PATCH 7/8] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=8A=E6=B8=B8?=
=?UTF-8?q?=E6=9E=84=E5=9B=BE=E5=A4=B1=E8=B4=A5=E5=88=A4=E6=96=AD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
controller/midjourney.go | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 41db4bf..313bae2 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"log"
"net/http"
@@ -16,6 +15,8 @@ import (
"strconv"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
func UpdateMidjourneyTaskBulk() {
@@ -147,7 +148,7 @@ func UpdateMidjourneyTaskBulk() {
task.Buttons = string(buttonStr)
}
- if task.Progress != "100%" && responseItem.FailReason != "" {
+ if task.Progress != "100%" && responseItem.FailReason != "" || task.Progress == "100%" && task.Status == "FAILURE" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
From f7a4f18aff63daf318884b5bb99c67537d45cad1 Mon Sep 17 00:00:00 2001
From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com>
Date: Sat, 30 Mar 2024 16:26:39 +0800
Subject: [PATCH 8/8] Update midjourney.go
---
controller/midjourney.go | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 313bae2..b5b832b 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/gin-gonic/gin"
"io"
"log"
"net/http"
@@ -15,8 +16,6 @@ import (
"strconv"
"strings"
"time"
-
- "github.com/gin-gonic/gin"
)
func UpdateMidjourneyTaskBulk() {
@@ -148,7 +147,7 @@ func UpdateMidjourneyTaskBulk() {
task.Buttons = string(buttonStr)
}
- if task.Progress != "100%" && responseItem.FailReason != "" || task.Progress == "100%" && task.Status == "FAILURE" {
+ if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)