♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
Buer
2024-01-19 02:47:10 +08:00
committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

View File

@@ -3,9 +3,9 @@ package base
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"strings"
@@ -13,11 +13,7 @@ import (
"github.com/gin-gonic/gin"
)
var StopFinishReason = "stop"
var StopFinishReasonToolFunction = "tool_calls"
var StopFinishReasonCallFunction = "function_call"
type BaseProvider struct {
type ProviderConfig struct {
BaseURL string
Completions string
ChatCompletions string
@@ -29,8 +25,15 @@ type BaseProvider struct {
ImagesGenerations string
ImagesEdit string
ImagesVariations string
Context *gin.Context
Channel *model.Channel
}
type BaseProvider struct {
OriginalModel string
Usage *types.Usage
Config ProviderConfig
Context *gin.Context
Channel *model.Channel
Requester *requester.HTTPRequester
}
// 获取基础URL
@@ -39,11 +42,7 @@ func (p *BaseProvider) GetBaseURL() string {
return p.Channel.GetBaseURL()
}
return p.BaseURL
}
func (p *BaseProvider) SetChannel(channel *model.Channel) {
p.Channel = channel
return p.Config.BaseURL
}
// 获取完整请求URL
@@ -62,104 +61,85 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
}
}
// 发送请求
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
if openAIErrorWithStatusCode != nil {
return
}
defer resp.Body.Close()
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
if openAIErrorWithStatusCode != nil {
return
}
if rawOutput {
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
} else {
jsonResponse, err := json.Marshal(openAIResponse)
if err != nil {
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = p.Context.Writer.Write(jsonResponse)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
}
return nil
func (p *BaseProvider) GetUsage() *types.Usage {
return p.Usage
}
func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
// 发送请求
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
}
common.PutHttpClient(client)
defer resp.Body.Close()
// 处理响应
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp)
}
for k, v := range resp.Header {
p.Context.Writer.Header().Set(k, v[0])
}
p.Context.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(p.Context.Writer, resp.Body)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
func (p *BaseProvider) SetUsage(usage *types.Usage) {
p.Usage = usage
}
func (p *BaseProvider) SupportAPI(relayMode int) bool {
func (p *BaseProvider) SetContext(c *gin.Context) {
p.Context = c
}
func (p *BaseProvider) SetOriginalModel(ModelName string) {
p.OriginalModel = ModelName
}
func (p *BaseProvider) GetOriginalModel() string {
return p.OriginalModel
}
func (p *BaseProvider) GetChannel() *model.Channel {
return p.Channel
}
func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
p.OriginalModel = modelName
modelMapping := p.Channel.GetModelMapping()
if modelMapping == "" || modelMapping == "{}" {
return modelName, nil
}
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return "", err
}
if modelMap[modelName] != "" {
return modelMap[modelName], nil
}
return modelName, nil
}
func (p *BaseProvider) GetAPIUri(relayMode int) string {
switch relayMode {
case common.RelayModeChatCompletions:
return p.ChatCompletions != ""
return p.Config.ChatCompletions
case common.RelayModeCompletions:
return p.Completions != ""
return p.Config.Completions
case common.RelayModeEmbeddings:
return p.Embeddings != ""
return p.Config.Embeddings
case common.RelayModeAudioSpeech:
return p.AudioSpeech != ""
return p.Config.AudioSpeech
case common.RelayModeAudioTranscription:
return p.AudioTranscriptions != ""
return p.Config.AudioTranscriptions
case common.RelayModeAudioTranslation:
return p.AudioTranslations != ""
return p.Config.AudioTranslations
case common.RelayModeModerations:
return p.Moderation != ""
return p.Config.Moderation
case common.RelayModeImagesGenerations:
return p.ImagesGenerations != ""
return p.Config.ImagesGenerations
case common.RelayModeImagesEdits:
return p.ImagesEdit != ""
return p.Config.ImagesEdit
case common.RelayModeImagesVariations:
return p.ImagesVariations != ""
return p.Config.ImagesVariations
default:
return false
return ""
}
}
func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) {
url = p.GetAPIUri(relayMode)
if url == "" {
err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented)
return
}
return
}

View File

@@ -0,0 +1,7 @@
package base
import "one-api/types"
type BaseHandler struct {
Usage *types.Usage
}

View File

@@ -2,84 +2,108 @@ package base
import (
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Requestable interface {
types.CompletionRequest | types.ChatCompletionRequest | types.EmbeddingRequest | types.ModerationRequest | types.SpeechAudioRequest | types.AudioRequest | types.ImageRequest | types.ImageEditRequest
}
// 基础接口
type ProviderInterface interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
SupportAPI(relayMode int) bool
SetChannel(channel *model.Channel)
// 获取基础URL
// GetBaseURL() string
// 获取完整请求URL
// GetFullRequestURL(requestURL string, modelName string) string
// 获取请求头
// GetRequestHeaders() (headers map[string]string)
// 获取用量
GetUsage() *types.Usage
// 设置用量
SetUsage(usage *types.Usage)
// 设置Context
SetContext(c *gin.Context)
// 设置原始模型
SetOriginalModel(ModelName string)
// 获取原始模型
GetOriginalModel() string
// SupportAPI(relayMode int) bool
GetChannel() *model.Channel
ModelMappingHandler(modelName string) (string, error)
}
// 完成接口
type CompletionInterface interface {
ProviderInterface
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
}
// 聊天接口
type ChatInterface interface {
ProviderInterface
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
}
// 嵌入接口
type EmbeddingsInterface interface {
ProviderInterface
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode)
}
// 审查接口
type ModerationInterface interface {
ProviderInterface
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode)
}
// 文字转语音接口
type SpeechInterface interface {
ProviderInterface
SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode)
}
// 语音转文字接口
type TranscriptionsInterface interface {
ProviderInterface
TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
}
// 语音翻译接口
type TranslationInterface interface {
ProviderInterface
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
}
// 图片生成接口
type ImageGenerationsInterface interface {
ProviderInterface
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
// 图片编辑接口
type ImageEditsInterface interface {
ProviderInterface
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
type ImageVariationsInterface interface {
ProviderInterface
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
// 余额接口
type BalanceInterface interface {
Balance(channel *model.Channel) (float64, error)
Balance() (float64, error)
}
type ProviderResponseHandler interface {
// 响应处理函数
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
}
// type ProviderResponseHandler interface {
// // 响应处理函数
// ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
// }