diff --git a/README.md b/README.md index abb2379..41cbec0 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,8 @@ 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 ## 渠道重试 -渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。 -如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 +渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。 +如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 ### 缓存设置方法 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` diff --git a/constant/midjourney.go b/constant/midjourney.go index 8b88a44..6d0b5ac 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -2,6 +2,8 @@ package constant var MjNotifyEnabled = false +var MjModeClearEnabled = false + const ( MjErrorUnknown = 5 MjRequestError = 4 diff --git a/controller/relay.go b/controller/relay.go index c6d850d..0fd9d7a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -92,6 +92,9 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt if openaiErr.StatusCode == http.StatusTooManyRequests { return true } + if openaiErr.StatusCode == 307 { + return true + } if openaiErr.StatusCode/100 == 5 { // 超时不重试 if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 { diff --git a/go.mod b/go.mod index 2ecbc52..958700a 100644 --- a/go.mod +++ b/go.mod @@ -16,9 +16,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.6 - github.com/samber/lo v1.38.1 + github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 + github.com/stripe/stripe-go/v76 v76.21.0 golang.org/x/crypto v0.21.0 golang.org/x/image v0.15.0 gorm.io/driver/mysql v1.4.3 @@ -55,20 +55,18 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/stripe/stripe-go/v76 v76.21.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect + golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.1.0 // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index d498440..3f8223c 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -113,8 +113,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -135,12 +133,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= +github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= -github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI= -github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -175,16 +171,16 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= +golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -206,7 +202,6 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= diff --git a/middleware/distributor.go b/middleware/distributor.go index 35cb6df..108c783 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) { userId := c.GetInt("id") var channel *model.Channel channelId, ok := c.Get("specific_channel_id") + modelRequest, shouldSelectChannel, err := getModelRequest(c) + userGroup, _ := model.CacheGetUserGroup(userId) + c.Set("group", userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) { return } } else { - shouldSelectChannel := true // Select a channel for the user - var modelRequest ModelRequest - var err error - if strings.Contains(c.Request.URL.Path, "/mj/") { - relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) - if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || - relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || - relayMode == relayconstant.RelayModeMidjourneyNotify || - relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { - shouldSelectChannel = false - } else { - midjourneyRequest := dto.MidjourneyRequest{} - err = common.UnmarshalBodyReusable(c, &midjourneyRequest) - if err != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) - return - } - midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) - if mjErr != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) - return - } - if midjourneyModel == "" { - if !success { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") - return - } else { - // task fetch, task fetch by condition, notify - shouldSelectChannel = false - } - } - modelRequest.Model = midjourneyModel - } - c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } - if err != nil { - abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } - } - } // check token model mapping modelLimitEnable := c.GetBool("token_model_limit_enabled") if modelLimitEnable { @@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) { } } - userGroup, _ := model.CacheGetUserGroup(userId) - c.Set("group", userGroup) if shouldSelectChannel { channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) if err != nil { @@ -147,14 +83,87 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } - SetupContextForSelectedChannel(c, channel, modelRequest.Model) } } + SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() } } +func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { + var modelRequest ModelRequest + shouldSelectChannel := true + var err error + if strings.Contains(c.Request.URL.Path, "/mj/") { + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return nil, false, err + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return nil, false, fmt.Errorf(mjErr.Description) + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return nil, false, fmt.Errorf("无效的请求, 无法解析模型") + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel + } + c.Set("relay_mode", relayMode) + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + return nil, false, err + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = "tts-1" + } else { + modelRequest.Model = "whisper-1" + } + } + } + return &modelRequest, shouldSelectChannel, nil +} + func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + c.Set("original_model", modelName) // for retry + if channel == nil { + return + } c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) @@ -168,7 +177,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } c.Set("auto_ban", ban) c.Set("model_mapping", channel.GetModelMapping()) - c.Set("original_model", modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 diff --git a/model/ability.go b/model/ability.go index f522967..7fd52bc 100644 --- a/model/ability.go +++ b/model/ability.go @@ -3,6 +3,8 @@ package model import ( "errors" "fmt" + "github.com/samber/lo" + "gorm.io/gorm" "one-api/common" "strings" ) @@ -27,8 +29,7 @@ func GetGroupModels(group string) []string { return models } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - var abilities []Ability +func getPriority(group string, model string, retry int) (int, error) { groupCol := "`group`" trueVal := "1" if common.UsingPostgreSQL { @@ -36,9 +37,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { trueVal = "true" } - var err error = nil + var priorities []int + err := DB.Model(&Ability{}). + Select("DISTINCT(priority)"). + Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). + Order("priority DESC"). // 按优先级降序排序 + Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 + + if err != nil { + // 处理错误 + return 0, err + } + + // 确定要使用的优先级 + var priorityToUse int + if retry >= len(priorities) { + // 如果重试次数大于优先级数,则使用最小的优先级 + priorityToUse = priorities[len(priorities)-1] + } else { + priorityToUse = priorities[retry] + } + return priorityToUse, nil +} + +func getChannelQuery(group string, model string, retry int) *gorm.DB { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + if retry != 0 { + priority, err := getPriority(group, model, retry) + if err != nil { + common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) + } else { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority) + } + } + + return channelQuery +} + +func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { + var abilities []Ability + + var err error = nil + channelQuery := getChannelQuery(group, model, retry) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("weight DESC").Find(&abilities).Error } else { @@ -88,7 +135,16 @@ func (channel *Channel) AddAbilities() error { abilities = append(abilities, ability) } } - return DB.Create(&abilities).Error + if len(abilities) == 0 { + return nil + } + for _, chunk := range lo.Chunk(abilities, 50) { + err := DB.Create(&chunk).Error + if err != nil { + return err + } + } + return nil } func (channel *Channel) DeleteAbilities() error { diff --git a/model/cache.go b/model/cache.go index 01245c9..f8ac584 100644 --- a/model/cache.go +++ b/model/cache.go @@ -296,7 +296,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, retry) } channelSyncLock.RLock() defer channelSyncLock.RUnlock() diff --git a/model/option.go b/model/option.go index a1b0738..da6cfcb 100644 --- a/model/option.go +++ b/model/option.go @@ -97,6 +97,7 @@ func InitOptionMap() { common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) @@ -204,6 +205,8 @@ func updateOptionMap(key string, value string) (err error) { common.DefaultCollapseSidebar = boolValue case "MjNotifyEnabled": constant.MjNotifyEnabled = boolValue + case "MjModeClearEnabled": + constant.MjModeClearEnabled = boolValue case "CheckSensitiveEnabled": constant.CheckSensitiveEnabled = boolValue case "CheckSensitiveOnPromptEnabled": diff --git a/service/channel.go b/service/channel.go index 6ce444d..82ffd77 100644 --- a/service/channel.go +++ b/service/channel.go @@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { return true } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true + } else if strings.HasPrefix(err.Message, "You exceeded your current quota") { + return true } return false } diff --git a/service/midjourney.go b/service/midjourney.go index ae13464..ccf5141 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -172,6 +172,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) // make new request with mapResult } + if constant.MjModeClearEnabled { + if prompt, ok := mapResult["prompt"].(string); ok { + prompt = strings.Replace(prompt, "--fast", "", -1) + prompt = strings.Replace(prompt, "--relax", "", -1) + prompt = strings.Replace(prompt, "--turbo", "", -1) + + mapResult["prompt"] = prompt + } + } reqBody, err := json.Marshal(mapResult) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index f42fe57..14b70d7 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -36,6 +36,7 @@ const OperationSetting = () => { StopOnSensitiveEnabled: '', SensitiveWords: '', MjNotifyEnabled: '', + MjModeClearEnabled: '', DrawingEnabled: '', DataExportEnabled: '', DataExportDefaultTime: 'hour', @@ -312,6 +313,12 @@ const OperationSetting = () => { name='MjNotifyEnabled' onChange={handleInputChange} /> +
屏蔽词过滤设置