From 462c328d4bc33e927c1ac0de5b709f1274492f06 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 6 Apr 2024 20:45:18 +0800 Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=9C=AA?= =?UTF-8?q?=E5=BC=80=E5=90=AF=E7=BC=93=E5=AD=98=E4=B8=8B=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E9=87=8D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/ability.go | 52 +++++++++++++++++++++++++++++++++++++++++++++--- model/cache.go | 2 +- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/model/ability.go b/model/ability.go index f522967..01fea9e 100644 --- a/model/ability.go +++ b/model/ability.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "gorm.io/gorm" "one-api/common" "strings" ) @@ -27,8 +28,7 @@ func GetGroupModels(group string) []string { return models } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - var abilities []Ability +func getPriority(group string, model string, retry int) (int, error) { groupCol := "`group`" trueVal := "1" if common.UsingPostgreSQL { @@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { trueVal = "true" } - var err error = nil + var priorities []int + err := DB.Model(&Ability{}). + Select("DISTINCT(priority)"). + Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). + Order("priority DESC"). // 按优先级降序排序 + Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 + + if err != nil { + // 处理错误 + return 0, err + } + + // 确定要使用的优先级 + var priorityToUse int + if retry >= len(priorities) { + // 如果重试次数大于优先级数,则使用最小的优先级 + priorityToUse = priorities[len(priorities)-1] + } else { + priorityToUse = priorities[retry] + } + return priorityToUse, nil +} + +func getChannelQuery(group string, model string, retry int) *gorm.DB { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + if retry != 0 { + priority, err := getPriority(group, model, retry) + if err != nil { + common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) + } else { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority) + } + } + + return channelQuery +} + +func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { + var abilities []Ability + + var err error = nil + channelQuery := getChannelQuery(group, model, retry) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("weight DESC").Find(&abilities).Error } else { diff --git a/model/cache.go b/model/cache.go index 78bdc17..dc2ed3b 100644 --- a/model/cache.go +++ b/model/cache.go @@ -272,7 +272,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, retry) } channelSyncLock.RLock() defer channelSyncLock.RUnlock() From 497cc32634fe75880e86708a1920b7d5fba297f3 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 6 Apr 2024 20:47:03 +0800 Subject: [PATCH 02/11] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index abb2379..7c878af 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 ## 渠道重试 -渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。 +渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。 如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 ### 缓存设置方法 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 From fbdb17022ca42e0665065f72d992f993bc878ca5 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 6 Apr 2024 20:49:19 +0800 Subject: [PATCH 03/11] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7c878af..41cbec0 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ ## 渠道重试 渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。 -如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 +如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 ### 缓存设置方法 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` From 5961de03e73cc1eb7de12c02a9a7b2aad9e29c28 Mon Sep 17 00:00:00 2001 From: iszcz <74706321+iszcz@users.noreply.github.com> Date: Sat, 6 Apr 2024 22:59:23 +0800 Subject: [PATCH 04/11] =?UTF-8?q?=E6=B8=85=E9=99=A4--mode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constant/midjourney.go | 2 ++ model/option.go | 3 +++ service/midjourney.go | 9 +++++++++ web/src/components/OperationSetting.js | 7 +++++++ 4 files changed, 21 insertions(+) diff --git a/constant/midjourney.go b/constant/midjourney.go index 8b88a44..6d0b5ac 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -2,6 +2,8 @@ package constant var MjNotifyEnabled = false +var MjModeClearEnabled = false + const ( MjErrorUnknown = 5 MjRequestError = 4 diff --git a/model/option.go b/model/option.go index 057d3b7..8432141 100644 --- a/model/option.go +++ b/model/option.go @@ -92,6 +92,7 @@ func InitOptionMap() { common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) @@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) { common.DefaultCollapseSidebar = boolValue case "MjNotifyEnabled": constant.MjNotifyEnabled = boolValue + case "MjModeClearEnabled": + constant.MjModeClearEnabled = boolValue case "CheckSensitiveEnabled": constant.CheckSensitiveEnabled = boolValue case "CheckSensitiveOnPromptEnabled": diff --git a/service/midjourney.go b/service/midjourney.go index ae13464..ccf5141 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -172,6 +172,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) // make new request with mapResult } + if constant.MjModeClearEnabled { + if prompt, ok := mapResult["prompt"].(string); ok { + prompt = strings.Replace(prompt, "--fast", "", -1) + prompt = strings.Replace(prompt, "--relax", "", -1) + prompt = strings.Replace(prompt, "--turbo", "", -1) + + mapResult["prompt"] = prompt + } + } reqBody, err := json.Marshal(mapResult) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index f42fe57..14b70d7 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -36,6 +36,7 @@ const OperationSetting = () => { StopOnSensitiveEnabled: '', SensitiveWords: '', MjNotifyEnabled: '', + MjModeClearEnabled: '', DrawingEnabled: '', DataExportEnabled: '', DataExportDefaultTime: 'hour', @@ -312,6 +313,12 @@ const OperationSetting = () => { name='MjNotifyEnabled' onChange={handleInputChange} /> +
屏蔽词过滤设置
From 2d1d1b463190ffec0d7116ac94dafcc5d28f94e8 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sun, 7 Apr 2024 14:42:03 +0800 Subject: [PATCH 05/11] update go-epay --- controller/topup.go | 5 +++-- go.mod | 8 ++++---- go.sum | 10 ++++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index e938a15..08493f9 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -2,9 +2,10 @@ package controller import ( "fmt" + "github.com/Calcium-Ion/go-epay/epay" "github.com/gin-gonic/gin" "github.com/samber/lo" - epay "github.com/star-horizon/go-epay" + "log" "net/url" "one-api/common" @@ -30,7 +31,7 @@ func GetEpayClient() *epay.Client { if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" { return nil } - withUrl, err := epay.NewClientWithUrl(&epay.Config{ + withUrl, err := epay.NewClient(&epay.Config{ PartnerID: common.EpayId, Key: common.EpayKey, }, common.PayAddress) diff --git a/go.mod b/go.mod index b0c7220..62bc80e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/Calcium-Ion/go-epay v0.0.2 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 @@ -16,9 +17,8 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.6 - github.com/samber/lo v1.38.1 + github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 golang.org/x/crypto v0.21.0 golang.org/x/image v0.15.0 gorm.io/driver/mysql v1.4.3 @@ -65,9 +65,9 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect + golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.1.0 // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index 5b17b48..5bb3189 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc= +github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= +github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc= +github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= @@ -137,6 +141,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= +github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI= @@ -175,6 +181,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= +golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -182,6 +190,8 @@ golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= From 34bf8f8945e4c7952fbdd0802a4f099c312e2f63 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sun, 7 Apr 2024 22:08:11 +0800 Subject: [PATCH 06/11] fix: select channel --- middleware/distributor.go | 141 ++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 68 deletions(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 35cb6df..e922662 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) { userId := c.GetInt("id") var channel *model.Channel channelId, ok := c.Get("specific_channel_id") + modelRequest, shouldSelectChannel, err := getModelRequest(c) + userGroup, _ := model.CacheGetUserGroup(userId) + c.Set("group", userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) { return } } else { - shouldSelectChannel := true // Select a channel for the user - var modelRequest ModelRequest - var err error - if strings.Contains(c.Request.URL.Path, "/mj/") { - relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) - if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || - relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || - relayMode == relayconstant.RelayModeMidjourneyNotify || - relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { - shouldSelectChannel = false - } else { - midjourneyRequest := dto.MidjourneyRequest{} - err = common.UnmarshalBodyReusable(c, &midjourneyRequest) - if err != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) - return - } - midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) - if mjErr != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) - return - } - if midjourneyModel == "" { - if !success { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") - return - } else { - // task fetch, task fetch by condition, notify - shouldSelectChannel = false - } - } - modelRequest.Model = midjourneyModel - } - c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } - if err != nil { - abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } - } - } // check token model mapping modelLimitEnable := c.GetBool("token_model_limit_enabled") if modelLimitEnable { @@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) { } } - userGroup, _ := model.CacheGetUserGroup(userId) - c.Set("group", userGroup) if shouldSelectChannel { channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) if err != nil { @@ -147,13 +83,82 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } - SetupContextForSelectedChannel(c, channel, modelRequest.Model) } } + SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() } } +func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { + var modelRequest ModelRequest + shouldSelectChannel := true + var err error + if strings.Contains(c.Request.URL.Path, "/mj/") { + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return nil, false, err + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return nil, false, fmt.Errorf(mjErr.Description) + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return nil, false, fmt.Errorf("无效的请求, 无法解析模型") + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel + } + c.Set("relay_mode", relayMode) + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + return nil, false, err + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = "tts-1" + } else { + modelRequest.Model = "whisper-1" + } + } + } + return &modelRequest, shouldSelectChannel, nil +} + func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) From a7cfce24d0cc30753e1680c420d456284b03cea1 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sun, 7 Apr 2024 22:22:27 +0800 Subject: [PATCH 07/11] feat: automatically ban channels that exceeded quota --- service/channel.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/channel.go b/service/channel.go index 6ce444d..82ffd77 100644 --- a/service/channel.go +++ b/service/channel.go @@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { return true } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true + } else if strings.HasPrefix(err.Message, "You exceeded your current quota") { + return true } return false } From c5f6d0e06370abcce7dbf3a3fbdb5beafa157f9a Mon Sep 17 00:00:00 2001 From: h1xy <48129611+h1xy@users.noreply.github.com> Date: Mon, 8 Apr 2024 02:12:47 +0800 Subject: [PATCH 08/11] Fix: CompletionRatio is not working for openrouter.ai https://openrouter.ai/docs#models Model name of openrouter is prefix with company name, e.g. "model": "anthropic/claude-3-opus:beta", therefore, CompletionRatio will not working for it which is only work for prefix with claude-xxx --- common/model-ratio.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index f618af8..2ff1aa4 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -206,11 +206,11 @@ func GetCompletionRatio(name string) float64 { } return 2 } - if strings.HasPrefix(name, "claude-instant-1") { + if strings.Contains(name, "claude-instant-1") { return 3 - } else if strings.HasPrefix(name, "claude-2") { + } else if strings.Contains(name, "claude-2") { return 3 - } else if strings.HasPrefix(name, "claude-3") { + } else if strings.Contains(name, "claude-3") { return 5 } if strings.HasPrefix(name, "mistral-") { From 60d7ed3fb5cf7da842e97eb1b85dcf397de75d37 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 8 Apr 2024 13:48:36 +0800 Subject: [PATCH 09/11] fix: distributor panic --- middleware/distributor.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index e922662..108c783 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -160,6 +160,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + c.Set("original_model", modelName) // for retry + if channel == nil { + return + } c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) @@ -173,7 +177,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } c.Set("auto_ban", ban) c.Set("model_mapping", channel.GetModelMapping()) - c.Set("original_model", modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 From 2d849e0dd63e10d00c376967ed1d3f85ad1c814e Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 8 Apr 2024 14:10:09 +0800 Subject: [PATCH 10/11] =?UTF-8?q?fix:=20307=E6=9C=AC=E5=9C=B0=E9=87=8D?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/controller/relay.go b/controller/relay.go index c6d850d..0fd9d7a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -92,6 +92,9 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt if openaiErr.StatusCode == http.StatusTooManyRequests { return true } + if openaiErr.StatusCode == 307 { + return true + } if openaiErr.StatusCode/100 == 5 { // 超时不重试 if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 { From 320da09f364c03def78ba9539ed72fcfd3591091 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Mon, 8 Apr 2024 23:51:51 +0800 Subject: [PATCH 11/11] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B8=A0?= =?UTF-8?q?=E9=81=93=E4=B8=80=E6=AC=A1=E6=80=A7=E6=B7=BB=E5=8A=A0=E5=BE=88?= =?UTF-8?q?=E5=A4=9Amodel=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复渠道一次性添加很多model并且group多 提示失败 too many SQL variables --- model/ability.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/model/ability.go b/model/ability.go index 01fea9e..7fd52bc 100644 --- a/model/ability.go +++ b/model/ability.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "github.com/samber/lo" "gorm.io/gorm" "one-api/common" "strings" @@ -134,7 +135,16 @@ func (channel *Channel) AddAbilities() error { abilities = append(abilities, ability) } } - return DB.Create(&abilities).Error + if len(abilities) == 0 { + return nil + } + for _, chunk := range lo.Chunk(abilities, 50) { + err := DB.Create(&chunk).Error + if err != nil { + return err + } + } + return nil } func (channel *Channel) DeleteAbilities() error {