mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 21:33:41 +08:00 
			
		
		
		
	merge upstream
Signed-off-by: wozulong <>
This commit is contained in:
		@@ -60,8 +60,8 @@
 | 
			
		||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
 | 
			
		||||
 | 
			
		||||
## 渠道重试
 | 
			
		||||
渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。  
 | 
			
		||||
如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。  
 | 
			
		||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。  
 | 
			
		||||
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。  
 | 
			
		||||
### 缓存设置方法
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
			
		||||
    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,8 @@ package constant
 | 
			
		||||
 | 
			
		||||
var MjNotifyEnabled = false
 | 
			
		||||
 | 
			
		||||
var MjModeClearEnabled = false
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	MjErrorUnknown = 5
 | 
			
		||||
	MjRequestError = 4
 | 
			
		||||
 
 | 
			
		||||
@@ -92,6 +92,9 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
 | 
			
		||||
	if openaiErr.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if openaiErr.StatusCode == 307 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if openaiErr.StatusCode/100 == 5 {
 | 
			
		||||
		// 超时不重试
 | 
			
		||||
		if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							@@ -16,9 +16,9 @@ require (
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.6
 | 
			
		||||
	github.com/samber/lo v1.38.1
 | 
			
		||||
	github.com/samber/lo v1.39.0
 | 
			
		||||
	github.com/shirou/gopsutil v3.21.11+incompatible
 | 
			
		||||
	github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
 | 
			
		||||
	github.com/stripe/stripe-go/v76 v76.21.0
 | 
			
		||||
	golang.org/x/crypto v0.21.0
 | 
			
		||||
	golang.org/x/image v0.15.0
 | 
			
		||||
	gorm.io/driver/mysql v1.4.3
 | 
			
		||||
@@ -55,20 +55,18 @@ require (
 | 
			
		||||
	github.com/leodido/go-urn v1.4.0 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.20 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
 | 
			
		||||
	github.com/mitchellh/mapstructure v1.5.0 // indirect
 | 
			
		||||
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
			
		||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
			
		||||
	github.com/stripe/stripe-go/v76 v76.21.0 // indirect
 | 
			
		||||
	github.com/tklauser/go-sysconf v0.3.12 // indirect
 | 
			
		||||
	github.com/tklauser/numcpus v0.6.1 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
			
		||||
	github.com/yusufpapurcu/wmi v1.2.3 // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
	golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
 | 
			
		||||
	golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
 | 
			
		||||
	golang.org/x/net v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.1.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.18.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.14.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.30.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										19
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								go.sum
									
									
									
									
									
								
							@@ -62,8 +62,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
 | 
			
		||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
 | 
			
		||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 | 
			
		||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
 | 
			
		||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
@@ -113,8 +113,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 | 
			
		||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
 | 
			
		||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
 | 
			
		||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
 | 
			
		||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 | 
			
		||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 | 
			
		||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 | 
			
		||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 | 
			
		||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 | 
			
		||||
@@ -135,12 +133,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
 | 
			
		||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
 | 
			
		||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
 | 
			
		||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 | 
			
		||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
 | 
			
		||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 | 
			
		||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
 | 
			
		||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
 | 
			
		||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
 | 
			
		||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 | 
			
		||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 | 
			
		||||
@@ -175,16 +171,16 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
 | 
			
		||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
 | 
			
		||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
 | 
			
		||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
 | 
			
		||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
 | 
			
		||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
 | 
			
		||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
 | 
			
		||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 | 
			
		||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
 | 
			
		||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
 | 
			
		||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
 | 
			
		||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 | 
			
		||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
@@ -206,7 +202,6 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
 | 
			
		||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 | 
			
		||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 | 
			
		||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 | 
			
		||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		var channel *model.Channel
 | 
			
		||||
		channelId, ok := c.Get("specific_channel_id")
 | 
			
		||||
		modelRequest, shouldSelectChannel, err := getModelRequest(c)
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		c.Set("group", userGroup)
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			shouldSelectChannel := true
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			var err error
 | 
			
		||||
			if strings.Contains(c.Request.URL.Path, "/mj/") {
 | 
			
		||||
				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 | 
			
		||||
				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 | 
			
		||||
					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
 | 
			
		||||
					relayMode == relayconstant.RelayModeMidjourneyNotify ||
 | 
			
		||||
					relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
 | 
			
		||||
					shouldSelectChannel = false
 | 
			
		||||
				} else {
 | 
			
		||||
					midjourneyRequest := dto.MidjourneyRequest{}
 | 
			
		||||
					err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
 | 
			
		||||
					if mjErr != nil {
 | 
			
		||||
						abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					if midjourneyModel == "" {
 | 
			
		||||
						if !success {
 | 
			
		||||
							abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
 | 
			
		||||
							return
 | 
			
		||||
						} else {
 | 
			
		||||
							// task fetch, task fetch by condition, notify
 | 
			
		||||
							shouldSelectChannel = false
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					modelRequest.Model = midjourneyModel
 | 
			
		||||
				}
 | 
			
		||||
				c.Set("relay_mode", relayMode)
 | 
			
		||||
			} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = c.Param("model")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "dall-e"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 | 
			
		||||
						modelRequest.Model = "tts-1"
 | 
			
		||||
					} else {
 | 
			
		||||
						modelRequest.Model = "whisper-1"
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// check token model mapping
 | 
			
		||||
			modelLimitEnable := c.GetBool("token_model_limit_enabled")
 | 
			
		||||
			if modelLimitEnable {
 | 
			
		||||
@@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
			c.Set("group", userGroup)
 | 
			
		||||
			if shouldSelectChannel {
 | 
			
		||||
				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
@@ -147,14 +83,87 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				SetupContextForSelectedChannel(c, channel, modelRequest.Model)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		SetupContextForSelectedChannel(c, channel, modelRequest.Model)
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 | 
			
		||||
	var modelRequest ModelRequest
 | 
			
		||||
	shouldSelectChannel := true
 | 
			
		||||
	var err error
 | 
			
		||||
	if strings.Contains(c.Request.URL.Path, "/mj/") {
 | 
			
		||||
		relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 | 
			
		||||
		if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 | 
			
		||||
			relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
 | 
			
		||||
			relayMode == relayconstant.RelayModeMidjourneyNotify ||
 | 
			
		||||
			relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
 | 
			
		||||
			shouldSelectChannel = false
 | 
			
		||||
		} else {
 | 
			
		||||
			midjourneyRequest := dto.MidjourneyRequest{}
 | 
			
		||||
			err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
 | 
			
		||||
				return nil, false, err
 | 
			
		||||
			}
 | 
			
		||||
			midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
 | 
			
		||||
			if mjErr != nil {
 | 
			
		||||
				abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
 | 
			
		||||
				return nil, false, fmt.Errorf(mjErr.Description)
 | 
			
		||||
			}
 | 
			
		||||
			if midjourneyModel == "" {
 | 
			
		||||
				if !success {
 | 
			
		||||
					abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
 | 
			
		||||
					return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
 | 
			
		||||
				} else {
 | 
			
		||||
					// task fetch, task fetch by condition, notify
 | 
			
		||||
					shouldSelectChannel = false
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			modelRequest.Model = midjourneyModel
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("relay_mode", relayMode)
 | 
			
		||||
	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
		err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = c.Param("model")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = "dall-e"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 | 
			
		||||
				modelRequest.Model = "tts-1"
 | 
			
		||||
			} else {
 | 
			
		||||
				modelRequest.Model = "whisper-1"
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &modelRequest, shouldSelectChannel, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 | 
			
		||||
	c.Set("original_model", modelName) // for retry
 | 
			
		||||
	if channel == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.Set("channel", channel.Type)
 | 
			
		||||
	c.Set("channel_id", channel.Id)
 | 
			
		||||
	c.Set("channel_name", channel.Name)
 | 
			
		||||
@@ -168,7 +177,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 | 
			
		||||
	}
 | 
			
		||||
	c.Set("auto_ban", ban)
 | 
			
		||||
	c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
	c.Set("original_model", modelName) // for retry
 | 
			
		||||
	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
	c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
	// TODO: api_version统一
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,8 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/samber/lo"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
@@ -27,8 +29,7 @@ func GetGroupModels(group string) []string {
 | 
			
		||||
	return models
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
	var abilities []Ability
 | 
			
		||||
func getPriority(group string, model string, retry int) (int, error) {
 | 
			
		||||
	groupCol := "`group`"
 | 
			
		||||
	trueVal := "1"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
@@ -36,9 +37,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
		trueVal = "true"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	var priorities []int
 | 
			
		||||
	err := DB.Model(&Ability{}).
 | 
			
		||||
		Select("DISTINCT(priority)").
 | 
			
		||||
		Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
 | 
			
		||||
		Order("priority DESC").              // 按优先级降序排序
 | 
			
		||||
		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// 处理错误
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 确定要使用的优先级
 | 
			
		||||
	var priorityToUse int
 | 
			
		||||
	if retry >= len(priorities) {
 | 
			
		||||
		// 如果重试次数大于优先级数,则使用最小的优先级
 | 
			
		||||
		priorityToUse = priorities[len(priorities)-1]
 | 
			
		||||
	} else {
 | 
			
		||||
		priorityToUse = priorities[retry]
 | 
			
		||||
	}
 | 
			
		||||
	return priorityToUse, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
 | 
			
		||||
	groupCol := "`group`"
 | 
			
		||||
	trueVal := "1"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		groupCol = `"group"`
 | 
			
		||||
		trueVal = "true"
 | 
			
		||||
	}
 | 
			
		||||
	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
 | 
			
		||||
	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
 | 
			
		||||
	if retry != 0 {
 | 
			
		||||
		priority, err := getPriority(group, model, retry)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
 | 
			
		||||
		} else {
 | 
			
		||||
			channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return channelQuery
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
 | 
			
		||||
	var abilities []Ability
 | 
			
		||||
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	channelQuery := getChannelQuery(group, model, retry)
 | 
			
		||||
	if common.UsingSQLite || common.UsingPostgreSQL {
 | 
			
		||||
		err = channelQuery.Order("weight DESC").Find(&abilities).Error
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -88,7 +135,16 @@ func (channel *Channel) AddAbilities() error {
 | 
			
		||||
			abilities = append(abilities, ability)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return DB.Create(&abilities).Error
 | 
			
		||||
	if len(abilities) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	for _, chunk := range lo.Chunk(abilities, 50) {
 | 
			
		||||
		err := DB.Create(&chunk).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) DeleteAbilities() error {
 | 
			
		||||
 
 | 
			
		||||
@@ -296,7 +296,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
 | 
			
		||||
 | 
			
		||||
	// if memory cache is disabled, get channel directly from database
 | 
			
		||||
	if !common.MemoryCacheEnabled {
 | 
			
		||||
		return GetRandomSatisfiedChannel(group, model)
 | 
			
		||||
		return GetRandomSatisfiedChannel(group, model, retry)
 | 
			
		||||
	}
 | 
			
		||||
	channelSyncLock.RLock()
 | 
			
		||||
	defer channelSyncLock.RUnlock()
 | 
			
		||||
 
 | 
			
		||||
@@ -97,6 +97,7 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
 | 
			
		||||
	common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
 | 
			
		||||
	common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
 | 
			
		||||
	common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
 | 
			
		||||
	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 | 
			
		||||
	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
 | 
			
		||||
	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
 | 
			
		||||
@@ -204,6 +205,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			common.DefaultCollapseSidebar = boolValue
 | 
			
		||||
		case "MjNotifyEnabled":
 | 
			
		||||
			constant.MjNotifyEnabled = boolValue
 | 
			
		||||
		case "MjModeClearEnabled":
 | 
			
		||||
			constant.MjModeClearEnabled = boolValue
 | 
			
		||||
		case "CheckSensitiveEnabled":
 | 
			
		||||
			constant.CheckSensitiveEnabled = boolValue
 | 
			
		||||
		case "CheckSensitiveOnPromptEnabled":
 | 
			
		||||
 
 | 
			
		||||
@@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
 | 
			
		||||
		return true
 | 
			
		||||
	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
 | 
			
		||||
		return true
 | 
			
		||||
	} else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -172,6 +172,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
 | 
			
		||||
		//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
		// make new request with mapResult
 | 
			
		||||
	}
 | 
			
		||||
	if constant.MjModeClearEnabled {
 | 
			
		||||
		if prompt, ok := mapResult["prompt"].(string); ok {
 | 
			
		||||
		    prompt = strings.Replace(prompt, "--fast", "", -1)
 | 
			
		||||
		    prompt = strings.Replace(prompt, "--relax", "", -1)
 | 
			
		||||
		    prompt = strings.Replace(prompt, "--turbo", "", -1)
 | 
			
		||||
		    
 | 
			
		||||
		    mapResult["prompt"] = prompt
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	reqBody, err := json.Marshal(mapResult)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
 | 
			
		||||
 
 | 
			
		||||
@@ -36,6 +36,7 @@ const OperationSetting = () => {
 | 
			
		||||
    StopOnSensitiveEnabled: '',
 | 
			
		||||
    SensitiveWords: '',
 | 
			
		||||
    MjNotifyEnabled: '',
 | 
			
		||||
    MjModeClearEnabled: '',
 | 
			
		||||
    DrawingEnabled: '',
 | 
			
		||||
    DataExportEnabled: '',
 | 
			
		||||
    DataExportDefaultTime: 'hour',
 | 
			
		||||
@@ -312,6 +313,12 @@ const OperationSetting = () => {
 | 
			
		||||
              name='MjNotifyEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.MjModeClearEnabled === 'true'}
 | 
			
		||||
              label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
 | 
			
		||||
              name='MjModeClearEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>屏蔽词过滤设置</Header>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user