mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-18 06:03:42 +08:00
♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
@@ -3,9 +3,9 @@ package base
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
@@ -13,11 +13,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var StopFinishReason = "stop"
|
||||
var StopFinishReasonToolFunction = "tool_calls"
|
||||
var StopFinishReasonCallFunction = "function_call"
|
||||
|
||||
type BaseProvider struct {
|
||||
type ProviderConfig struct {
|
||||
BaseURL string
|
||||
Completions string
|
||||
ChatCompletions string
|
||||
@@ -29,8 +25,15 @@ type BaseProvider struct {
|
||||
ImagesGenerations string
|
||||
ImagesEdit string
|
||||
ImagesVariations string
|
||||
Context *gin.Context
|
||||
Channel *model.Channel
|
||||
}
|
||||
|
||||
type BaseProvider struct {
|
||||
OriginalModel string
|
||||
Usage *types.Usage
|
||||
Config ProviderConfig
|
||||
Context *gin.Context
|
||||
Channel *model.Channel
|
||||
Requester *requester.HTTPRequester
|
||||
}
|
||||
|
||||
// 获取基础URL
|
||||
@@ -39,11 +42,7 @@ func (p *BaseProvider) GetBaseURL() string {
|
||||
return p.Channel.GetBaseURL()
|
||||
}
|
||||
|
||||
return p.BaseURL
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SetChannel(channel *model.Channel) {
|
||||
p.Channel = channel
|
||||
return p.Config.BaseURL
|
||||
}
|
||||
|
||||
// 获取完整请求URL
|
||||
@@ -62,104 +61,85 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if rawOutput {
|
||||
for k, v := range resp.Header {
|
||||
p.Context.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err := io.Copy(p.Context.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
jsonResponse, err := json.Marshal(openAIResponse)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = p.Context.Writer.Write(jsonResponse)
|
||||
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
func (p *BaseProvider) GetUsage() *types.Usage {
|
||||
return p.Usage
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
p.Context.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(p.Context.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
func (p *BaseProvider) SetUsage(usage *types.Usage) {
|
||||
p.Usage = usage
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
||||
func (p *BaseProvider) SetContext(c *gin.Context) {
|
||||
p.Context = c
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SetOriginalModel(ModelName string) {
|
||||
p.OriginalModel = ModelName
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetOriginalModel() string {
|
||||
return p.OriginalModel
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetChannel() *model.Channel {
|
||||
return p.Channel
|
||||
}
|
||||
|
||||
func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
|
||||
p.OriginalModel = modelName
|
||||
|
||||
modelMapping := p.Channel.GetModelMapping()
|
||||
|
||||
if modelMapping == "" || modelMapping == "{}" {
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if modelMap[modelName] != "" {
|
||||
return modelMap[modelName], nil
|
||||
}
|
||||
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetAPIUri(relayMode int) string {
|
||||
switch relayMode {
|
||||
case common.RelayModeChatCompletions:
|
||||
return p.ChatCompletions != ""
|
||||
return p.Config.ChatCompletions
|
||||
case common.RelayModeCompletions:
|
||||
return p.Completions != ""
|
||||
return p.Config.Completions
|
||||
case common.RelayModeEmbeddings:
|
||||
return p.Embeddings != ""
|
||||
return p.Config.Embeddings
|
||||
case common.RelayModeAudioSpeech:
|
||||
return p.AudioSpeech != ""
|
||||
return p.Config.AudioSpeech
|
||||
case common.RelayModeAudioTranscription:
|
||||
return p.AudioTranscriptions != ""
|
||||
return p.Config.AudioTranscriptions
|
||||
case common.RelayModeAudioTranslation:
|
||||
return p.AudioTranslations != ""
|
||||
return p.Config.AudioTranslations
|
||||
case common.RelayModeModerations:
|
||||
return p.Moderation != ""
|
||||
return p.Config.Moderation
|
||||
case common.RelayModeImagesGenerations:
|
||||
return p.ImagesGenerations != ""
|
||||
return p.Config.ImagesGenerations
|
||||
case common.RelayModeImagesEdits:
|
||||
return p.ImagesEdit != ""
|
||||
return p.Config.ImagesEdit
|
||||
case common.RelayModeImagesVariations:
|
||||
return p.ImagesVariations != ""
|
||||
return p.Config.ImagesVariations
|
||||
default:
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) {
|
||||
url = p.GetAPIUri(relayMode)
|
||||
if url == "" {
|
||||
err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
7
providers/base/handler.go
Normal file
7
providers/base/handler.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package base
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type BaseHandler struct {
|
||||
Usage *types.Usage
|
||||
}
|
||||
@@ -2,84 +2,108 @@ package base
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Requestable interface {
|
||||
types.CompletionRequest | types.ChatCompletionRequest | types.EmbeddingRequest | types.ModerationRequest | types.SpeechAudioRequest | types.AudioRequest | types.ImageRequest | types.ImageEditRequest
|
||||
}
|
||||
|
||||
// 基础接口
|
||||
type ProviderInterface interface {
|
||||
GetBaseURL() string
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
SupportAPI(relayMode int) bool
|
||||
SetChannel(channel *model.Channel)
|
||||
// 获取基础URL
|
||||
// GetBaseURL() string
|
||||
// 获取完整请求URL
|
||||
// GetFullRequestURL(requestURL string, modelName string) string
|
||||
// 获取请求头
|
||||
// GetRequestHeaders() (headers map[string]string)
|
||||
// 获取用量
|
||||
GetUsage() *types.Usage
|
||||
// 设置用量
|
||||
SetUsage(usage *types.Usage)
|
||||
// 设置Context
|
||||
SetContext(c *gin.Context)
|
||||
// 设置原始模型
|
||||
SetOriginalModel(ModelName string)
|
||||
// 获取原始模型
|
||||
GetOriginalModel() string
|
||||
|
||||
// SupportAPI(relayMode int) bool
|
||||
GetChannel() *model.Channel
|
||||
ModelMappingHandler(modelName string) (string, error)
|
||||
}
|
||||
|
||||
// 完成接口
|
||||
type CompletionInterface interface {
|
||||
ProviderInterface
|
||||
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 聊天接口
|
||||
type ChatInterface interface {
|
||||
ProviderInterface
|
||||
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 嵌入接口
|
||||
type EmbeddingsInterface interface {
|
||||
ProviderInterface
|
||||
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 审查接口
|
||||
type ModerationInterface interface {
|
||||
ProviderInterface
|
||||
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 文字转语音接口
|
||||
type SpeechInterface interface {
|
||||
ProviderInterface
|
||||
SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 语音转文字接口
|
||||
type TranscriptionsInterface interface {
|
||||
ProviderInterface
|
||||
TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 语音翻译接口
|
||||
type TranslationInterface interface {
|
||||
ProviderInterface
|
||||
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片生成接口
|
||||
type ImageGenerationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片编辑接口
|
||||
type ImageEditsInterface interface {
|
||||
ProviderInterface
|
||||
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ImageVariationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
Balance(channel *model.Channel) (float64, error)
|
||||
Balance() (float64, error)
|
||||
}
|
||||
|
||||
type ProviderResponseHandler interface {
|
||||
// 响应处理函数
|
||||
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
// type ProviderResponseHandler interface {
|
||||
// // 响应处理函数
|
||||
// ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
// }
|
||||
|
||||
Reference in New Issue
Block a user