♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* Improved function calling for Baidu/Xunfei; functions can now be invoked normally from `lobe-chat` (a request sketch follows below)
Author: Buer
Date: 2024-01-19 02:47:10 +08:00
Committed by: GitHub
Parent: 0bfe1f5779
Commit: ef041e28a1
96 changed files with 4339 additions and 3276 deletions
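
For context, a minimal sketch of the kind of OpenAI-style tool-call request that `lobe-chat` issues and that the refactored Baidu/Xunfei providers now translate for their upstream APIs. The field names follow the public OpenAI chat-completions schema; the gateway URL, API key, model name, and function definition below are placeholders, not values taken from this commit.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    // An OpenAI-compatible request carrying a tool definition, as a client
    // such as lobe-chat would send it to the one-api gateway.
    payload := map[string]any{
        "model":  "SparkDesk", // assumed channel model name
        "stream": false,
        "messages": []map[string]any{
            {"role": "user", "content": "What is the weather in Beijing?"},
        },
        "tools": []map[string]any{
            {
                "type": "function",
                "function": map[string]any{
                    "name":        "get_weather",
                    "description": "Look up the current weather for a city",
                    "parameters": map[string]any{
                        "type": "object",
                        "properties": map[string]any{
                            "city": map[string]string{"type": "string"},
                        },
                        "required": []string{"city"},
                    },
                },
            },
        },
    }

    body, _ := json.Marshal(payload)
    // Placeholder gateway address and key.
    req, err := http.NewRequest("POST", "http://localhost:3000/v1/chat/completions", bytes.NewReader(body))
    if err != nil {
        fmt.Println("build request failed:", err)
        return
    }
    req.Header.Set("Authorization", "Bearer sk-xxx")
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        fmt.Println("request failed:", err)
        return
    }
    defer resp.Body.Close()
    fmt.Println("status:", resp.Status)
}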

@@ -7,31 +7,53 @@ import (
"fmt"
"net/url"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type XunfeiProviderFactory struct{}
// Create a XunfeiProvider
func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func (f XunfeiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &XunfeiProvider{
BaseProvider: base.BaseProvider{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "true",
Context: c,
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, nil),
},
wsRequester: requester.NewWSRequester(channel.Proxy),
}
}
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
base.BaseProvider
domain string
apiId string
domain string
apiId string
wsRequester *requester.WSRequester
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "wss://spark-api.xf-yun.com",
ChatCompletions: "/",
}
}
// Error handling
func errorHandle(xunfeiError *XunfeiChatResponse) *types.OpenAIError {
if xunfeiError.Header.Code == 0 {
return nil
}
return &types.OpenAIError{
Message: xunfeiError.Header.Message,
Type: "xunfei_error",
Code: xunfeiError.Header.Code,
}
}
// Get the request headers
@@ -68,7 +90,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.Config.BaseURL, apiVersion), apiKey, apiSecret)
return domain, authUrl
}
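
A self-contained sketch of the factory pattern introduced in this file: the provider is now built from a channel record instead of a `gin.Context`, and its static endpoint data comes from `getConfig()`. The types below are simplified stand-ins, not one-api's real `model.Channel`, `base.ProviderConfig`, or requester wiring.

package main

import "fmt"

// Channel is a stand-in for model.Channel: per-channel credentials and proxy.
type Channel struct {
    Key   string
    Proxy string
}

// ProviderConfig is a stand-in for base.ProviderConfig: static endpoint data.
type ProviderConfig struct {
    BaseURL         string
    ChatCompletions string
}

type XunfeiProvider struct {
    Config  ProviderConfig
    Channel *Channel
}

type XunfeiProviderFactory struct{}

func getConfig() ProviderConfig {
    return ProviderConfig{
        BaseURL:         "wss://spark-api.xf-yun.com",
        ChatCompletions: "/",
    }
}

// Create wires a channel into a ready-to-use provider, mirroring the new
// Create(channel *model.Channel) signature in the diff above.
func (f XunfeiProviderFactory) Create(channel *Channel) *XunfeiProvider {
    return &XunfeiProvider{
        Config:  getConfig(),
        Channel: channel,
    }
}

func main() {
    p := XunfeiProviderFactory{}.Create(&Channel{Key: "appid|secret|key", Proxy: ""})
    fmt.Println(p.Config.BaseURL + p.Config.ChatCompletions)
}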

@@ -2,140 +2,93 @@ package xunfei
import (
"encoding/json"
"fmt"
"errors"
"io"
"net/http"
"one-api/common"
"one-api/providers/base"
"one-api/common/requester"
"one-api/types"
"time"
"strings"
"github.com/gorilla/websocket"
)
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
if request.Stream {
return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate())
} else {
return p.sendRequest(dataChan, stopChan, request.GetFunctionCate())
}
type xunfeiHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
if xunfeiResponse.Header.Code != 0 {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
p.Context.Writer.Header().Set("Content-Type", "application/json")
_, _ = p.Context.Writer.Write(jsonResponse)
return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
// Wait for the first response on dataChan
xunfeiResponse, ok := <-dataChan
if !ok {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
}
if xunfeiResponse.Header.Code != 0 {
errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
// If the first response carries no error, set the stream headers and start streaming
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
// Handle the first response
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
xunfeiRequest := p.convertFromChatOpenai(request)
chatHandler := &xunfeiHandler{
Usage: p.Usage,
Request: request,
}
stream, errWithCode := requester.SendWSJsonRequest[XunfeiChatResponse](wsConn, xunfeiRequest, chatHandler.handlerNotStream)
if errWithCode != nil {
return nil, errWithCode
}
return chatHandler.convertToChatOpenai(stream)
// Handle subsequent responses
for {
select {
case xunfeiResponse, ok := <-dataChan:
if !ok {
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
}
})
return usage, nil
}
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
xunfeiRequest := p.convertFromChatOpenai(request)
chatHandler := &xunfeiHandler{
Usage: p.Usage,
Request: request,
}
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
}
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
authUrl := p.GetFullRequestURL(url, request.Model)
wsConn, err := p.wsRequester.NewRequest(authUrl, nil)
if err != nil {
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return wsConn, nil
}
func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Role: types.ChatMessageRoleUser,
Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Role: types.ChatMessageRoleAssistant,
Content: "Okay",
})
} else if message.Role == types.ChatMessageRoleFunction {
messages = append(messages, XunfeiMessage{
Role: types.ChatMessageRoleUser,
Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
@@ -143,6 +96,7 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
})
}
}
xunfeiRequest := XunfeiChatRequest{}
if request.Tools != nil {
@@ -166,35 +120,57 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
return &xunfeiRequest
}
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
var content string
var xunfeiResponse XunfeiChatResponse
for {
response, err := stream.Recv()
if err != nil && !errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
}
if errors.Is(err, io.EOF) && response == nil {
break
}
if len((*response)[0].Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = (*response)[0]
content += xunfeiResponse.Payload.Choices.Text[0].Content
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
choice := types.ChatCompletionChoice{
Index: 0,
FinishReason: base.StopFinishReason,
FinishReason: types.FinishReasonStop,
}
xunfeiText := response.Payload.Choices.Text[0]
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
if xunfeiText.FunctionCall != nil {
choice.Message = types.ChatCompletionMessage{
Role: "assistant",
}
if functionCate == "tool" {
if h.Request.Tools != nil {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: response.Header.Sid,
Id: xunfeiResponse.Header.Sid,
Type: "function",
Function: *xunfeiText.FunctionCall,
Function: xunfeiText.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Message.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
@@ -204,97 +180,128 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, fun
}
}
fullTextResponse := types.ChatCompletionResponse{
ID: response.Header.Sid,
fullTextResponse := &types.ChatCompletionResponse{
ID: xunfeiResponse.Header.Sid,
Object: "chat.completion",
Model: "SparkDesk",
Model: h.Request.Model,
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &response.Payload.Usage.Text,
Usage: &xunfeiResponse.Payload.Usage.Text,
}
return &fullTextResponse
return fullTextResponse, nil
}
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
// If rawLine does not start with "{" (not a JSON frame), skip it and return
if !strings.HasPrefix(string(*rawLine), "{") {
*rawLine = nil
return nil, nil
}
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
data := p.requestOpenAI2Xunfei(textRequest)
err = conn.WriteJSON(data)
var xunfeiChatResponse XunfeiChatResponse
err := json.Unmarshal(*rawLine, &xunfeiChatResponse)
if err != nil {
return nil, nil, err
return nil, common.ErrorToOpenAIError(err)
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
error := errorHandle(&xunfeiChatResponse)
if error != nil {
return nil, error
}
return dataChan, stopChan, nil
if xunfeiChatResponse.Payload.Choices.Status == 2 {
*isFinished = true
}
h.Usage.PromptTokens = xunfeiChatResponse.Payload.Usage.Text.PromptTokens
h.Usage.CompletionTokens = xunfeiChatResponse.Payload.Usage.Text.CompletionTokens
h.Usage.TotalTokens = xunfeiChatResponse.Payload.Usage.Text.TotalTokens
return &xunfeiChatResponse, nil
}
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
if err != nil {
return err
}
var choice types.ChatCompletionStreamChoice
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
if *rawLine == nil {
return nil
}
*response = append(*response, *xunfeiChatResponse)
return nil
}
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
if err != nil {
return err
}
if *rawLine == nil {
return nil
}
return h.convertToOpenaiStream(xunfeiChatResponse, response)
}
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
},
}
xunfeiText := xunfeiChatResponse.Payload.Choices.Text[0]
if xunfeiText.FunctionCall != nil {
if functionCate == "tool" {
if h.Request.Tools != nil {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: xunfeiResponse.Header.Sid,
Id: xunfeiChatResponse.Header.Sid,
Index: 0,
Type: "function",
Function: *xunfeiText.FunctionCall,
Function: xunfeiText.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
choice.FinishReason = types.FinishReasonToolCalls
} else {
choice.Delta.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
choice.FinishReason = types.FinishReasonFunctionCall
}
} else {
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &base.StopFinishReason
choice.Delta.Content = xunfeiChatResponse.Payload.Choices.Text[0].Content
if xunfeiChatResponse.Payload.Choices.Status == 2 {
choice.FinishReason = types.FinishReasonStop
}
}
response := types.ChatCompletionStreamResponse{
ID: xunfeiResponse.Header.Sid,
chatCompletion := types.ChatCompletionStreamResponse{
ID: xunfeiChatResponse.Header.Sid,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []types.ChatCompletionStreamChoice{choice},
Model: h.Request.Model,
}
return &response
if xunfeiText.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
}
}
return nil
}
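
To show how the stream returned by `CreateChatCompletionStream` is meant to be consumed, a minimal sketch using a simplified stand-in for `requester.StreamReaderInterface`: the caller drains `Recv()` until `io.EOF` and forwards each chunk, the same loop shape used by `convertToChatOpenai` above. The reader and chunk types here are illustrative only.

package main

import (
    "errors"
    "fmt"
    "io"
)

// StreamReader is a simplified stand-in for requester.StreamReaderInterface.
type StreamReader[T any] interface {
    Recv() (*[]T, error)
    Close()
}

// fakeReader yields a fixed set of chunks and then io.EOF, mimicking the
// Spark connection closing once Payload.Choices.Status == 2.
type fakeReader struct{ chunks [][]string }

func (f *fakeReader) Recv() (*[]string, error) {
    if len(f.chunks) == 0 {
        return nil, io.EOF
    }
    out := f.chunks[0]
    f.chunks = f.chunks[1:]
    return &out, nil
}

func (f *fakeReader) Close() {}

// drain reads until EOF and forwards every chunk, as a relay handler would
// when writing SSE "data:" lines back to the client.
func drain(stream StreamReader[string]) error {
    defer stream.Close()
    for {
        chunks, err := stream.Recv()
        if errors.Is(err, io.EOF) {
            return nil
        }
        if err != nil {
            return err
        }
        for _, c := range *chunks {
            fmt.Print(c)
        }
    }
}

func main() {
    _ = drain(&fakeReader{chunks: [][]string{{"Hello, "}, {"Spark"}}})
    fmt.Println()
}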