This commit is contained in:
ckt1031
2023-08-28 00:35:15 +08:00
13 changed files with 283 additions and 36 deletions

View File

@@ -59,7 +59,10 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
var err error
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
@@ -85,7 +88,12 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
@@ -108,7 +116,7 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.ModelMapping)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure {
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
c.Set("api_version", channel.Other)
}
c.Next()