From 462c328d4bc33e927c1ac0de5b709f1274492f06 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 6 Apr 2024 20:45:18 +0800
Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=9C=AA?=
=?UTF-8?q?=E5=BC=80=E5=90=AF=E7=BC=93=E5=AD=98=E4=B8=8B=E6=9C=AC=E5=9C=B0?=
=?UTF-8?q?=E9=87=8D=E8=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
model/ability.go | 52 +++++++++++++++++++++++++++++++++++++++++++++---
model/cache.go | 2 +-
2 files changed, 50 insertions(+), 4 deletions(-)
diff --git a/model/ability.go b/model/ability.go
index f522967..01fea9e 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
+ "gorm.io/gorm"
"one-api/common"
"strings"
)
@@ -27,8 +28,7 @@ func GetGroupModels(group string) []string {
return models
}
-func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
- var abilities []Ability
+func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
trueVal = "true"
}
- var err error = nil
+ var priorities []int
+ err := DB.Model(&Ability{}).
+ Select("DISTINCT(priority)").
+ Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
+ Order("priority DESC"). // 按优先级降序排序
+ Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
+
+ if err != nil {
+ // 处理错误
+ return 0, err
+ }
+
+ // 确定要使用的优先级
+ var priorityToUse int
+ if retry >= len(priorities) {
+ // 如果重试次数大于优先级数,则使用最小的优先级
+ priorityToUse = priorities[len(priorities)-1]
+ } else {
+ priorityToUse = priorities[retry]
+ }
+ return priorityToUse, nil
+}
+
+func getChannelQuery(group string, model string, retry int) *gorm.DB {
+ groupCol := "`group`"
+ trueVal := "1"
+ if common.UsingPostgreSQL {
+ groupCol = `"group"`
+ trueVal = "true"
+ }
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+ if retry != 0 {
+ priority, err := getPriority(group, model, retry)
+ if err != nil {
+ common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
+ } else {
+ channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+ }
+ }
+
+ return channelQuery
+}
+
+func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+ var abilities []Ability
+
+ var err error = nil
+ channelQuery := getChannelQuery(group, model, retry)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
diff --git a/model/cache.go b/model/cache.go
index 78bdc17..dc2ed3b 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -272,7 +272,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model)
+ return GetRandomSatisfiedChannel(group, model, retry)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
From 497cc32634fe75880e86708a1920b7d5fba297f3 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 6 Apr 2024 20:47:03 +0800
Subject: [PATCH 02/11] update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index abb2379..7c878af 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
## 渠道重试
-渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。
+渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
From fbdb17022ca42e0665065f72d992f993bc878ca5 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 6 Apr 2024 20:49:19 +0800
Subject: [PATCH 03/11] update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 7c878af..41cbec0 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
-如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
+如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
From 5961de03e73cc1eb7de12c02a9a7b2aad9e29c28 Mon Sep 17 00:00:00 2001
From: iszcz <74706321+iszcz@users.noreply.github.com>
Date: Sat, 6 Apr 2024 22:59:23 +0800
Subject: [PATCH 04/11] =?UTF-8?q?=E6=B8=85=E9=99=A4--mode?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
constant/midjourney.go | 2 ++
model/option.go | 3 +++
service/midjourney.go | 9 +++++++++
web/src/components/OperationSetting.js | 7 +++++++
4 files changed, 21 insertions(+)
diff --git a/constant/midjourney.go b/constant/midjourney.go
index 8b88a44..6d0b5ac 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -2,6 +2,8 @@ package constant
var MjNotifyEnabled = false
+var MjModeClearEnabled = false
+
const (
MjErrorUnknown = 5
MjRequestError = 4
diff --git a/model/option.go b/model/option.go
index 057d3b7..8432141 100644
--- a/model/option.go
+++ b/model/option.go
@@ -92,6 +92,7 @@ func InitOptionMap() {
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
+ common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
+ case "MjModeClearEnabled":
+ constant.MjModeClearEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
diff --git a/service/midjourney.go b/service/midjourney.go
index ae13464..ccf5141 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -172,6 +172,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
}
+ if constant.MjModeClearEnabled {
+ if prompt, ok := mapResult["prompt"].(string); ok {
+ prompt = strings.Replace(prompt, "--fast", "", -1)
+ prompt = strings.Replace(prompt, "--relax", "", -1)
+ prompt = strings.Replace(prompt, "--turbo", "", -1)
+
+ mapResult["prompt"] = prompt
+ }
+ }
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index f42fe57..14b70d7 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -36,6 +36,7 @@ const OperationSetting = () => {
StopOnSensitiveEnabled: '',
SensitiveWords: '',
MjNotifyEnabled: '',
+ MjModeClearEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportDefaultTime: 'hour',
@@ -312,6 +313,12 @@ const OperationSetting = () => {
name='MjNotifyEnabled'
onChange={handleInputChange}
/>
+
From 2d1d1b463190ffec0d7116ac94dafcc5d28f94e8 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 7 Apr 2024 14:42:03 +0800
Subject: [PATCH 05/11] update go-epay
---
controller/topup.go | 5 +++--
go.mod | 8 ++++----
go.sum | 10 ++++++++++
3 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/controller/topup.go b/controller/topup.go
index e938a15..08493f9 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -2,9 +2,10 @@ package controller
import (
"fmt"
+ "github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
- epay "github.com/star-horizon/go-epay"
+
"log"
"net/url"
"one-api/common"
@@ -30,7 +31,7 @@ func GetEpayClient() *epay.Client {
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
return nil
}
- withUrl, err := epay.NewClientWithUrl(&epay.Config{
+ withUrl, err := epay.NewClient(&epay.Config{
PartnerID: common.EpayId,
Key: common.EpayKey,
}, common.PayAddress)
diff --git a/go.mod b/go.mod
index b0c7220..62bc80e 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,7 @@ module one-api
go 1.18
require (
+ github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
@@ -16,9 +17,8 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.6
- github.com/samber/lo v1.38.1
+ github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
- github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3
@@ -65,9 +65,9 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
- golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
+ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.21.0 // indirect
- golang.org/x/sync v0.1.0 // indirect
+ golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
diff --git a/go.sum b/go.sum
index 5b17b48..5bb3189 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,7 @@
+github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc=
+github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
+github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
+github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
@@ -137,6 +141,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
+github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
+github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
@@ -175,6 +181,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
+golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
+golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -182,6 +190,8 @@ golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
From 34bf8f8945e4c7952fbdd0802a4f099c312e2f63 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 7 Apr 2024 22:08:11 +0800
Subject: [PATCH 06/11] fix: select channel
---
middleware/distributor.go | 141 ++++++++++++++++++++------------------
1 file changed, 73 insertions(+), 68 deletions(-)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 35cb6df..e922662 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
+ modelRequest, shouldSelectChannel, err := getModelRequest(c)
+ userGroup, _ := model.CacheGetUserGroup(userId)
+ c.Set("group", userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
- shouldSelectChannel := true
// Select a channel for the user
- var modelRequest ModelRequest
- var err error
- if strings.Contains(c.Request.URL.Path, "/mj/") {
- relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
- if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
- relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
- relayMode == relayconstant.RelayModeMidjourneyNotify ||
- relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
- shouldSelectChannel = false
- } else {
- midjourneyRequest := dto.MidjourneyRequest{}
- err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
- if err != nil {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
- return
- }
- midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
- if mjErr != nil {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
- return
- }
- if midjourneyModel == "" {
- if !success {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
- return
- } else {
- // task fetch, task fetch by condition, notify
- shouldSelectChannel = false
- }
- }
- modelRequest.Model = midjourneyModel
- }
- c.Set("relay_mode", relayMode)
- } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
- }
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
- return
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "text-moderation-stable"
- }
- }
- if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- if modelRequest.Model == "" {
- modelRequest.Model = c.Param("model")
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "dall-e"
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
- if modelRequest.Model == "" {
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
- modelRequest.Model = "tts-1"
- } else {
- modelRequest.Model = "whisper-1"
- }
- }
- }
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
@@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) {
}
}
- userGroup, _ := model.CacheGetUserGroup(userId)
- c.Set("group", userGroup)
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
@@ -147,13 +83,82 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
- SetupContextForSelectedChannel(c, channel, modelRequest.Model)
}
}
+ SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
}
+func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
+ var modelRequest ModelRequest
+ shouldSelectChannel := true
+ var err error
+ if strings.Contains(c.Request.URL.Path, "/mj/") {
+ relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
+ relayMode == relayconstant.RelayModeMidjourneyNotify ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
+ shouldSelectChannel = false
+ } else {
+ midjourneyRequest := dto.MidjourneyRequest{}
+ err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
+ if err != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
+ return nil, false, err
+ }
+ midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
+ if mjErr != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
+ return nil, false, fmt.Errorf(mjErr.Description)
+ }
+ if midjourneyModel == "" {
+ if !success {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
+ return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
+ } else {
+ // task fetch, task fetch by condition, notify
+ shouldSelectChannel = false
+ }
+ }
+ modelRequest.Model = midjourneyModel
+ }
+ c.Set("relay_mode", relayMode)
+ } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ }
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+ return nil, false, err
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "text-moderation-stable"
+ }
+ }
+ if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = c.Param("model")
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "dall-e"
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+ if modelRequest.Model == "" {
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+ modelRequest.Model = "tts-1"
+ } else {
+ modelRequest.Model = "whisper-1"
+ }
+ }
+ }
+ return &modelRequest, shouldSelectChannel, nil
+}
+
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
From a7cfce24d0cc30753e1680c420d456284b03cea1 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 7 Apr 2024 22:22:27 +0800
Subject: [PATCH 07/11] feat: automatically ban channels that exceeded quota
---
service/channel.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/service/channel.go b/service/channel.go
index 6ce444d..82ffd77 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
+ } else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
+ return true
}
return false
}
From c5f6d0e06370abcce7dbf3a3fbdb5beafa157f9a Mon Sep 17 00:00:00 2001
From: h1xy <48129611+h1xy@users.noreply.github.com>
Date: Mon, 8 Apr 2024 02:12:47 +0800
Subject: [PATCH 08/11] Fix: CompletionRatio is not working for openrouter.ai
https://openrouter.ai/docs#models
Model name of openrouter is prefix with company name, e.g. "model": "anthropic/claude-3-opus:beta", therefore, CompletionRatio will not working for it which is only work for prefix with claude-xxx
---
common/model-ratio.go | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index f618af8..2ff1aa4 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -206,11 +206,11 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
- if strings.HasPrefix(name, "claude-instant-1") {
+ if strings.Contains(name, "claude-instant-1") {
return 3
- } else if strings.HasPrefix(name, "claude-2") {
+ } else if strings.Contains(name, "claude-2") {
return 3
- } else if strings.HasPrefix(name, "claude-3") {
+ } else if strings.Contains(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "mistral-") {
From 60d7ed3fb5cf7da842e97eb1b85dcf397de75d37 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Mon, 8 Apr 2024 13:48:36 +0800
Subject: [PATCH 09/11] fix: distributor panic
---
middleware/distributor.go | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index e922662..108c783 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -160,6 +160,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+ c.Set("original_model", modelName) // for retry
+ if channel == nil {
+ return
+ }
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
@@ -173,7 +177,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
- c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
From 2d849e0dd63e10d00c376967ed1d3f85ad1c814e Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Mon, 8 Apr 2024 14:10:09 +0800
Subject: [PATCH 10/11] =?UTF-8?q?fix:=20307=E6=9C=AC=E5=9C=B0=E9=87=8D?=
=?UTF-8?q?=E8=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
controller/relay.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/controller/relay.go b/controller/relay.go
index c6d850d..0fd9d7a 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -92,6 +92,9 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
+ if openaiErr.StatusCode == 307 {
+ return true
+ }
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
From 320da09f364c03def78ba9539ed72fcfd3591091 Mon Sep 17 00:00:00 2001
From: Xyfacai
Date: Mon, 8 Apr 2024 23:51:51 +0800
Subject: [PATCH 11/11] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B8=A0?=
=?UTF-8?q?=E9=81=93=E4=B8=80=E6=AC=A1=E6=80=A7=E6=B7=BB=E5=8A=A0=E5=BE=88?=
=?UTF-8?q?=E5=A4=9Amodel=E5=A4=B1=E8=B4=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
修复渠道一次性添加很多model并且group多
提示失败 too many SQL variables
---
model/ability.go | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/model/ability.go b/model/ability.go
index 01fea9e..7fd52bc 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
+ "github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
@@ -134,7 +135,16 @@ func (channel *Channel) AddAbilities() error {
abilities = append(abilities, ability)
}
}
- return DB.Create(&abilities).Error
+ if len(abilities) == 0 {
+ return nil
+ }
+ for _, chunk := range lo.Chunk(abilities, 50) {
+ err := DB.Create(&chunk).Error
+ if err != nil {
+ return err
+ }
+ }
+ return nil
}
func (channel *Channel) DeleteAbilities() error {