Merge remote-tracking branch 'origin/main'

# Conflicts:
#	controller/relay.go
#	main.go
#	middleware/distributor.go
This commit is contained in:
CaIon
2023-09-09 03:15:55 +08:00
26 changed files with 1027 additions and 187 deletions

View File

@@ -2,7 +2,6 @@ package middleware
import (
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
@@ -22,7 +21,6 @@ func Distribute() func(c *gin.Context) {
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
var err error
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
@@ -58,7 +56,6 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
@@ -79,7 +76,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
@@ -95,6 +91,11 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -118,8 +119,13 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.ModelMapping)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
}
c.Next()
}