merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-07-19 10:58:21 +08:00
72 changed files with 1989 additions and 1193 deletions

View File

@@ -1,6 +1,7 @@
package middleware
import (
"errors"
"fmt"
"net/http"
"one-api/common"
@@ -25,6 +26,10 @@ func Distribute() func(c *gin.Context) {
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if ok {
@@ -141,7 +146,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, err
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
@@ -154,18 +159,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
}
c.Set("relay_mode", relayMode)
}
return &modelRequest, shouldSelectChannel, nil
}
@@ -198,11 +207,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}