mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-14 20:23:46 +08:00
🎨 Change the method of getting channel parameters
This commit is contained in:
@@ -32,9 +32,9 @@ type AliProvider struct {
|
||||
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
if p.Context.GetString("plugin") != "" {
|
||||
headers["X-DashScope-Plugin"] = p.Context.GetString("plugin")
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||
if p.Channel.Other != "" {
|
||||
headers["X-DashScope-Plugin"] = p.Channel.Other
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
@@ -27,7 +27,7 @@ type AzureSpeechProvider struct {
|
||||
// 获取请求头
|
||||
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key")
|
||||
headers["Ocp-Apim-Subscription-Key"] = p.Channel.Key
|
||||
headers["Content-Type"] = "application/ssml+xml"
|
||||
headers["User-Agent"] = "OneAPI"
|
||||
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
|
||||
|
||||
@@ -63,7 +63,7 @@ func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||
apiKey := p.Context.GetString("api_key")
|
||||
apiKey := p.Channel.Key
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken BaiduAccessToken
|
||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@@ -28,17 +29,22 @@ type BaseProvider struct {
|
||||
ImagesVariations string
|
||||
Proxy string
|
||||
Context *gin.Context
|
||||
Channel *model.Channel
|
||||
}
|
||||
|
||||
// 获取基础URL
|
||||
func (p *BaseProvider) GetBaseURL() string {
|
||||
if p.Context.GetString("base_url") != "" {
|
||||
return p.Context.GetString("base_url")
|
||||
if p.Channel.GetBaseURL() != "" {
|
||||
return p.Channel.GetBaseURL()
|
||||
}
|
||||
|
||||
return p.BaseURL
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SetChannel(channel *model.Channel) {
|
||||
p.Channel = channel
|
||||
}
|
||||
|
||||
// 获取完整请求URL
|
||||
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
@@ -12,6 +12,7 @@ type ProviderInterface interface {
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
SupportAPI(relayMode int) bool
|
||||
SetChannel(channel *model.Channel)
|
||||
}
|
||||
|
||||
// 完成接口
|
||||
|
||||
@@ -28,7 +28,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
headers["x-api-key"] = p.Context.GetString("api_key")
|
||||
headers["x-api-key"] = p.Channel.Key
|
||||
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
|
||||
@@ -28,8 +28,8 @@ type GeminiProvider struct {
|
||||
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
version := "v1"
|
||||
if p.Context.GetString("api_version") != "" {
|
||||
version = p.Context.GetString("api_version")
|
||||
if p.Channel.Other != "" {
|
||||
version = p.Channel.Other
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
|
||||
@@ -40,7 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["x-goog-api-key"] = p.Context.GetString("api_key")
|
||||
headers["x-goog-api-key"] = p.Channel.Key
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
if p.IsAzure {
|
||||
apiVersion := p.Context.GetString("api_version")
|
||||
apiVersion := p.Channel.Other
|
||||
if modelName == "dall-e-2" {
|
||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||
// 已经没有dall-e-2了,所以暂时写死
|
||||
@@ -85,9 +85,9 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
if p.IsAzure {
|
||||
headers["api-key"] = p.Context.GetString("api_key")
|
||||
headers["api-key"] = p.Channel.Key
|
||||
} else {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||
}
|
||||
|
||||
return headers
|
||||
|
||||
@@ -29,7 +29,7 @@ type PalmProvider struct {
|
||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["x-goog-api-key"] = p.Context.GetString("api_key")
|
||||
headers["x-goog-api-key"] = p.Channel.Key
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/aigc2d"
|
||||
"one-api/providers/aiproxy"
|
||||
"one-api/providers/ali"
|
||||
@@ -55,19 +56,23 @@ func init() {
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
|
||||
factory, ok := providerFactories[channelType]
|
||||
func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface {
|
||||
factory, ok := providerFactories[channel.Type]
|
||||
var provider base.ProviderInterface
|
||||
if !ok {
|
||||
// 处理未找到的供应商工厂
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
if baseURL != "" {
|
||||
return openai.CreateOpenAIProvider(c, baseURL)
|
||||
if baseURL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
provider = openai.CreateOpenAIProvider(c, baseURL)
|
||||
}
|
||||
return factory.Create(c)
|
||||
provider = factory.Create(c)
|
||||
provider.SetChannel(channel)
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secret
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||
apiKey := p.Context.GetString("api_key")
|
||||
apiKey := p.Channel.Key
|
||||
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
@@ -42,7 +42,7 @@ func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
splits := strings.Split(p.Context.GetString("api_key"), "|")
|
||||
splits := strings.Split(p.Channel.Key, "|")
|
||||
if len(splits) != 3 {
|
||||
return ""
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
|
||||
query := p.Context.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = p.Context.GetString("api_version")
|
||||
apiVersion = p.Channel.Key
|
||||
}
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
|
||||
@@ -49,7 +49,7 @@ func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) s
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getZhipuToken() string {
|
||||
apikey := p.Context.GetString("api_key")
|
||||
apikey := p.Channel.Key
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(zhipuTokenData)
|
||||
|
||||
Reference in New Issue
Block a user