Merge branch 'songquanpeng' into sync_upstream

This commit is contained in:
Martial BE
2023-12-25 11:23:28 +08:00
24 changed files with 251 additions and 49 deletions

View File

@@ -39,6 +39,7 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: aliResponse.Model,
Choices: []types.ChatCompletionChoice{choice},
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
@@ -50,6 +51,8 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
return
}
const AliEnableSearchModelSuffix = "-internet"
// 获取聊天请求体
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
@@ -60,11 +63,23 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Model: aliModel,
Input: AliInput{
Messages: messages,
},
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
}
@@ -86,7 +101,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
}
if request.Stream {
usage, errWithCode = p.sendStreamRequest(req)
usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
@@ -100,7 +115,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
}
} else {
aliResponse := &AliChatResponse{}
aliResponse := &AliChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return
@@ -128,14 +145,14 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty
ID: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Model: aliResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}
// 发送流请求
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
usage = &types.Usage{}
@@ -181,7 +198,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
lastResponseText := ""
// lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -196,9 +213,10 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
aliResponse.Model = model
response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
// response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
// lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())