mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-29 22:56:39 +08:00
150 lines
3.4 KiB
Go
150 lines
3.4 KiB
Go
package base
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/common/requester"
|
|
"one-api/model"
|
|
"one-api/types"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// ProviderConfig describes where a provider's upstream API lives: a default
// base URL plus the request path used for each relay endpoint. A path left
// empty means the endpoint is not offered; GetSupportedAPIUri reports such
// endpoints as unsupported.
type ProviderConfig struct {
	// BaseURL is the default upstream root. A non-empty base URL configured
	// on the channel overrides it (see GetBaseURL).
	BaseURL string

	// Text endpoints.
	Completions, ChatCompletions, Embeddings string

	// Speech synthesis and moderation endpoints.
	AudioSpeech, Moderation string

	// Audio-to-text endpoints.
	AudioTranscriptions, AudioTranslations string

	// Image endpoints.
	ImagesGenerations, ImagesEdit, ImagesVariations string
}
|
|
|
|
type BaseProvider struct {
|
|
OriginalModel string
|
|
Usage *types.Usage
|
|
Config ProviderConfig
|
|
Context *gin.Context
|
|
Channel *model.Channel
|
|
Requester *requester.HTTPRequester
|
|
}
|
|
|
|
// 获取基础URL
|
|
func (p *BaseProvider) GetBaseURL() string {
|
|
if p.Channel.GetBaseURL() != "" {
|
|
return p.Channel.GetBaseURL()
|
|
}
|
|
|
|
return p.Config.BaseURL
|
|
}
|
|
|
|
// 获取完整请求URL
|
|
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
|
|
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
}
|
|
|
|
// 获取请求头
|
|
func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
|
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
|
if headers["Content-Type"] == "" {
|
|
headers["Content-Type"] = "application/json"
|
|
}
|
|
}
|
|
|
|
func (p *BaseProvider) GetUsage() *types.Usage {
|
|
return p.Usage
|
|
}
|
|
|
|
func (p *BaseProvider) SetUsage(usage *types.Usage) {
|
|
p.Usage = usage
|
|
}
|
|
|
|
func (p *BaseProvider) SetContext(c *gin.Context) {
|
|
p.Context = c
|
|
}
|
|
|
|
func (p *BaseProvider) SetOriginalModel(ModelName string) {
|
|
p.OriginalModel = ModelName
|
|
}
|
|
|
|
func (p *BaseProvider) GetOriginalModel() string {
|
|
return p.OriginalModel
|
|
}
|
|
|
|
func (p *BaseProvider) GetChannel() *model.Channel {
|
|
return p.Channel
|
|
}
|
|
|
|
func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
|
|
p.OriginalModel = modelName
|
|
|
|
modelMapping := p.Channel.GetModelMapping()
|
|
|
|
if modelMapping == "" || modelMapping == "{}" {
|
|
return modelName, nil
|
|
}
|
|
|
|
modelMap := make(map[string]string)
|
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if modelMap[modelName] != "" {
|
|
return modelMap[modelName], nil
|
|
}
|
|
|
|
return modelName, nil
|
|
}
|
|
|
|
func (p *BaseProvider) GetAPIUri(relayMode int) string {
|
|
switch relayMode {
|
|
case common.RelayModeChatCompletions:
|
|
return p.Config.ChatCompletions
|
|
case common.RelayModeCompletions:
|
|
return p.Config.Completions
|
|
case common.RelayModeEmbeddings:
|
|
return p.Config.Embeddings
|
|
case common.RelayModeAudioSpeech:
|
|
return p.Config.AudioSpeech
|
|
case common.RelayModeAudioTranscription:
|
|
return p.Config.AudioTranscriptions
|
|
case common.RelayModeAudioTranslation:
|
|
return p.Config.AudioTranslations
|
|
case common.RelayModeModerations:
|
|
return p.Config.Moderation
|
|
case common.RelayModeImagesGenerations:
|
|
return p.Config.ImagesGenerations
|
|
case common.RelayModeImagesEdits:
|
|
return p.Config.ImagesEdit
|
|
case common.RelayModeImagesVariations:
|
|
return p.Config.ImagesVariations
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) {
|
|
url = p.GetAPIUri(relayMode)
|
|
if url == "" {
|
|
err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented)
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (p *BaseProvider) GetRequester() *requester.HTTPRequester {
|
|
return p.Requester
|
|
}
|