one-api/providers/providers.go
2024-05-29 01:56:14 +08:00

98 lines
3.7 KiB
Go

package providers
import (
"one-api/common/config"
"one-api/model"
"one-api/providers/ali"
"one-api/providers/azure"
azurespeech "one-api/providers/azureSpeech"
"one-api/providers/baichuan"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/bedrock"
"one-api/providers/claude"
"one-api/providers/cloudflareAI"
"one-api/providers/cohere"
"one-api/providers/coze"
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
"one-api/providers/hunyuan"
"one-api/providers/lingyi"
"one-api/providers/midjourney"
"one-api/providers/minimax"
"one-api/providers/mistral"
"one-api/providers/moonshot"
"one-api/providers/ollama"
"one-api/providers/openai"
"one-api/providers/palm"
"one-api/providers/stabilityAI"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
"github.com/gin-gonic/gin"
)
// ProviderFactory builds a provider implementation for a channel.
// Factories are registered per channel type in providerFactories and
// looked up by GetProvider.
type ProviderFactory interface {
	// Create builds a provider for the given channel.
	Create(channel *model.Channel) base.ProviderInterface
}
// providerFactories maps a channel type (config.ChannelType*) to the factory
// that builds its provider. For channel types with no entry, GetProvider falls
// back to an OpenAI provider when a base URL is available.
//
// The table is declared as a composite literal instead of being filled in
// init(). That way it is fully populated during package variable
// initialization, before any init() runs. Presuming the config.ChannelType*
// values are constants, a duplicated key is also a compile-time error rather
// than a silent overwrite.
var providerFactories = map[int]ProviderFactory{
	config.ChannelTypeOpenAI:       openai.OpenAIProviderFactory{},
	config.ChannelTypeAzure:        azure.AzureProviderFactory{},
	config.ChannelTypeAli:          ali.AliProviderFactory{},
	config.ChannelTypeTencent:      tencent.TencentProviderFactory{},
	config.ChannelTypeBaidu:        baidu.BaiduProviderFactory{},
	config.ChannelTypeAnthropic:    claude.ClaudeProviderFactory{},
	config.ChannelTypePaLM:         palm.PalmProviderFactory{},
	config.ChannelTypeZhipu:        zhipu.ZhipuProviderFactory{},
	config.ChannelTypeXunfei:       xunfei.XunfeiProviderFactory{},
	config.ChannelTypeAzureSpeech:  azurespeech.AzureSpeechProviderFactory{},
	config.ChannelTypeGemini:       gemini.GeminiProviderFactory{},
	config.ChannelTypeBaichuan:     baichuan.BaichuanProviderFactory{},
	config.ChannelTypeMiniMax:      minimax.MiniMaxProviderFactory{},
	config.ChannelTypeDeepseek:     deepseek.DeepseekProviderFactory{},
	config.ChannelTypeMistral:      mistral.MistralProviderFactory{},
	config.ChannelTypeGroq:         groq.GroqProviderFactory{},
	config.ChannelTypeBedrock:      bedrock.BedrockProviderFactory{},
	config.ChannelTypeMidjourney:   midjourney.MidjourneyProviderFactory{},
	config.ChannelTypeCloudflareAI: cloudflareAI.CloudflareAIProviderFactory{},
	config.ChannelTypeCohere:       cohere.CohereProviderFactory{},
	config.ChannelTypeStabilityAI:  stabilityAI.StabilityAIProviderFactory{},
	config.ChannelTypeCoze:         coze.CozeProviderFactory{},
	config.ChannelTypeOllama:       ollama.OllamaProviderFactory{},
	config.ChannelTypeMoonshot:     moonshot.MoonshotProviderFactory{},
	config.ChannelTypeLingyi:       lingyi.LingyiProviderFactory{},
	config.ChannelTypeHunyuan:      hunyuan.HunyuanProviderFactory{},
}
// GetProvider returns the provider that serves channel, with its context set
// to the request context c.
//
// If a factory is registered for channel.Type in providerFactories, that
// factory builds the provider. Any other channel type is treated as
// OpenAI-compatible. Its base URL is the channel's own base URL when set;
// otherwise it is the per-type default from config.ChannelBaseURLs. If neither
// is set, no provider can be built and nil is returned.
//
// GetProvider also returns nil when channel is nil, or when the chosen
// constructor yields a nil provider. Previously both cases panicked.
func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface {
	if channel == nil {
		return nil
	}

	var provider base.ProviderInterface
	if factory, ok := providerFactories[channel.Type]; ok {
		provider = factory.Create(channel)
	} else {
		// No registered factory: fall back to an OpenAI-compatible provider,
		// preferring the channel's own base URL over the per-type default.
		baseURL := channel.GetBaseURL()
		if baseURL == "" {
			baseURL = config.ChannelBaseURLs[channel.Type]
		}
		if baseURL == "" {
			return nil
		}
		provider = openai.CreateOpenAIProvider(channel, baseURL)
	}

	// Guard against a constructor returning a nil interface: calling
	// SetContext on it would panic.
	if provider == nil {
		return nil
	}
	provider.SetContext(c)
	return provider
}