one-api/providers/xunfei/chat.go
2024-01-03 16:37:27 +08:00

301 lines
9.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package xunfei
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"one-api/common"
	"one-api/providers/base"
	"one-api/types"
	"strings"
	"time"

	"github.com/gorilla/websocket"
)
// ChatAction runs a chat completion against Xunfei over a websocket. It opens
// the connection via xunfeiMakeRequest and then hands the resulting frame
// channels to the SSE writer (request.Stream) or to the buffered JSON writer.
// isModelMapped and promptTokens are not used by this implementation
// (presumably required by a shared provider interface — confirm).
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	dataChan, stopChan, err := p.xunfeiMakeRequest(request, p.GetFullRequestURL(p.ChatCompletions, request.Model))
	if err != nil {
		return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
	}

	functionCate := request.GetFunctionCate()
	if !request.Stream {
		return p.sendRequest(dataChan, stopChan, functionCate)
	}
	return p.sendStreamRequest(dataChan, stopChan, functionCate)
}
// sendRequest collects every frame of a non-streaming Xunfei answer and writes
// a single OpenAI-style chat.completion JSON body.
//
// Frames are read until stopChan fires. The text of each frame's first choice
// is concatenated and its usage counters are summed. The header of the LAST
// received frame decides success: a non-zero Header.Code becomes an error. If
// the reader goroutine finished without delivering any frame, an error is
// returned instead of an empty "successful" completion.
//
// Returns the accumulated usage, or a wrapped error (HTTP 500) on failure.
func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	usage = &types.Usage{}
	var content strings.Builder
	var xunfeiResponse XunfeiChatResponse
	received := false

	for stop := false; !stop; {
		select {
		case xunfeiResponse = <-dataChan:
			received = true
			// Frames without text (e.g. error frames) still update xunfeiResponse
			// so their header is inspected below.
			if len(xunfeiResponse.Payload.Choices.Text) == 0 {
				continue
			}
			content.WriteString(xunfeiResponse.Payload.Choices.Text[0].Content)
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
		case stop = <-stopChan:
		}
	}

	if !received {
		return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
	}
	if xunfeiResponse.Header.Code != 0 {
		return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
	}

	// Put the concatenated text into the last frame so it can be converted as a whole.
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
	}
	xunfeiResponse.Payload.Choices.Text[0].Content = content.String()

	jsonResponse, err := json.Marshal(p.responseXunfei2OpenAI(&xunfeiResponse, functionCate))
	if err != nil {
		return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
	}
	p.Context.Writer.Header().Set("Content-Type", "application/json")
	// Best effort: the status line is already committed, nothing useful to do on failure.
	_, _ = p.Context.Writer.Write(jsonResponse)
	return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
// 等待第一个dataChan的响应
xunfeiResponse, ok := <-dataChan
if !ok {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
}
if xunfeiResponse.Header.Code != 0 {
errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
return nil, errWithCode
}
// 如果第一个响应没有错误设置StreamHeaders并开始streaming
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
// 处理第一个响应
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// 处理后续的响应
for {
select {
case xunfeiResponse, ok := <-dataChan:
if !ok {
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
}
})
return usage, nil
}
// requestOpenAI2Xunfei converts an OpenAI-style chat request into the Xunfei
// websocket request payload.
//
// Each "system" message is rewritten as a user turn followed by an assistant
// turn with content "Okay". Presumably the upstream API does not accept a
// system role — confirm. Tools take precedence over legacy Functions; only
// non-empty lists are forwarded. App id and domain come from the provider;
// temperature, max tokens and (see note) top_k come from the request.
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
	messages := make([]XunfeiMessage, 0, len(request.Messages))
	for _, message := range request.Messages {
		if message.Role == "system" {
			messages = append(messages,
				XunfeiMessage{Role: "user", Content: message.StringContent()},
				XunfeiMessage{Role: "assistant", Content: "Okay"},
			)
			continue
		}
		messages = append(messages, XunfeiMessage{
			Role:    message.Role,
			Content: message.StringContent(),
		})
	}

	xunfeiRequest := XunfeiChatRequest{}
	switch {
	case len(request.Tools) > 0:
		functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools))
		for i := range request.Tools {
			// Index into the slice instead of taking &tool.Function of a range
			// variable: before Go 1.22 that variable is reused across iterations,
			// so every pointer would alias the last tool's function.
			functions = append(functions, &request.Tools[i].Function)
		}
		xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{Text: functions}
	case len(request.Functions) > 0:
		xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{Text: request.Functions}
	}

	xunfeiRequest.Header.AppId = p.apiId
	xunfeiRequest.Parameter.Chat.Domain = p.domain
	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
	// NOTE(review): OpenAI's n (number of choices) is forwarded as top_k, which
	// looks like a semantic mismatch — confirm intended mapping.
	xunfeiRequest.Parameter.Chat.TopK = request.N
	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
	xunfeiRequest.Payload.Message.Text = messages
	return &xunfeiRequest
}
// responseXunfei2OpenAI turns an assembled Xunfei answer into an OpenAI-style
// "chat.completion" object with a single choice built from the first text item.
//
// If that item carries a function call, it is reported as a tool call when
// functionCate == "tool", and as a legacy function_call otherwise. Each case
// sets its matching finish reason. Plain text answers finish with "stop".
// The response is modified in place (padded with an empty item) when it has
// no text items. The returned Usage points into response.
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
	if len(response.Payload.Choices.Text) == 0 {
		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
	}
	item := response.Payload.Choices.Text[0]

	choice := types.ChatCompletionChoice{
		Index:        0,
		FinishReason: base.StopFinishReason,
		Message:      types.ChatCompletionMessage{Role: "assistant"},
	}
	switch {
	case item.FunctionCall == nil:
		choice.Message.Content = item.Content
	case functionCate == "tool":
		choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{{
			Id:       response.Header.Sid,
			Type:     "function",
			Function: *item.FunctionCall,
		}}
		choice.FinishReason = &base.StopFinishReasonToolFunction
	default:
		choice.Message.FunctionCall = item.FunctionCall
		choice.FinishReason = &base.StopFinishReasonCallFunction
	}

	return &types.ChatCompletionResponse{
		ID:      response.Header.Sid,
		Object:  "chat.completion",
		Model:   "SparkDesk",
		Created: common.GetTimestamp(),
		Choices: []types.ChatCompletionChoice{choice},
		Usage:   &response.Payload.Usage.Text,
	}
}
// xunfeiMakeRequest dials authUrl, sends the converted chat request, and starts
// a reader goroutine. It returns that goroutine's output channels:
//
//   - dataChan (unbuffered) receives every decoded frame, in order.
//   - stopChan receives true exactly once, after the last frame has been handed
//     over. That happens after a frame with Payload.Choices.Status == 2, a read
//     error, or a decode error.
//
// stopChan has a buffer of one, so the final signal never blocks the goroutine
// even if the consumer has stopped listening. Ordering relative to dataChan is
// still guaranteed because every dataChan send completes only when received.
// The websocket connection is closed on every exit path, including a failed
// initial write.
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
	d := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}
	conn, resp, err := d.Dial(authUrl, nil)
	if err != nil {
		return nil, nil, err
	}
	if resp.StatusCode != http.StatusSwitchingProtocols {
		// Defensive: never hand back nil channels with a nil error, callers would block forever.
		_ = conn.Close()
		return nil, nil, fmt.Errorf("xunfei websocket handshake returned status %d", resp.StatusCode)
	}

	if err := conn.WriteJSON(p.requestOpenAI2Xunfei(textRequest)); err != nil {
		_ = conn.Close()
		return nil, nil, err
	}

	dataChan := make(chan XunfeiChatResponse)
	stopChan := make(chan bool, 1)
	go func() {
		defer func() {
			if err := conn.Close(); err != nil {
				common.SysError("error closing websocket connection: " + err.Error())
			}
		}()
		for {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				common.SysError("error reading stream response: " + err.Error())
				break
			}
			var response XunfeiChatResponse
			if err := json.Unmarshal(msg, &response); err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				break
			}
			dataChan <- response
			// Status 2 is treated as the final frame of the answer.
			if response.Payload.Choices.Status == 2 {
				break
			}
		}
		stopChan <- true
	}()
	return dataChan, stopChan, nil
}
// streamResponseXunfei2OpenAI converts one Xunfei frame into an OpenAI-style
// "chat.completion.chunk" with a single delta built from the first text item.
//
// Function calls become a tool-call delta (functionCate == "tool") or a legacy
// function_call delta, each carrying its finish reason. Text deltas get the
// "stop" finish reason only when the frame's Payload.Choices.Status is 2. The
// frame is modified in place (padded with an empty item) when it has no text
// items.
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
	if len(xunfeiResponse.Payload.Choices.Text) == 0 {
		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
	}
	item := xunfeiResponse.Payload.Choices.Text[0]

	var choice types.ChatCompletionStreamChoice
	switch {
	case item.FunctionCall == nil:
		choice.Delta.Content = item.Content
		if xunfeiResponse.Payload.Choices.Status == 2 {
			choice.FinishReason = &base.StopFinishReason
		}
	case functionCate == "tool":
		choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{{
			Id:       xunfeiResponse.Header.Sid,
			Index:    0,
			Type:     "function",
			Function: *item.FunctionCall,
		}}
		choice.FinishReason = &base.StopFinishReasonToolFunction
	default:
		choice.Delta.FunctionCall = item.FunctionCall
		choice.FinishReason = &base.StopFinishReasonCallFunction
	}

	return &types.ChatCompletionStreamResponse{
		ID:      xunfeiResponse.Header.Sid,
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   "SparkDesk",
		Choices: []types.ChatCompletionStreamChoice{choice},
	}
}