🎨 Support qwen-vl-plus

This commit is contained in:
Martial BE
2023-12-29 10:59:26 +08:00
parent c4c89e8e1b
commit 211a862d54
5 changed files with 107 additions and 44 deletions

View File

@@ -26,21 +26,12 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
return
}
choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: aliResponse.Output.Text,
},
FinishReason: aliResponse.Output.FinishReason,
}
OpenAIResponse = types.ChatCompletionResponse{
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: aliResponse.Model,
Choices: []types.ChatCompletionChoice{choice},
Choices: aliResponse.Output.ToChatCompletionChoices(),
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
CompletionTokens: aliResponse.Usage.OutputTokens,
@@ -58,10 +49,31 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
if request.Model != "qwen-vl-plus" {
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
} else {
openaiContent := message.ParseContent()
var parts []AliMessagePart
for _, part := range openaiContent {
if part.Type == types.ContentTypeText {
parts = append(parts, AliMessagePart{
Text: part.Text,
})
} else if part.Type == types.ContentTypeImageURL {
parts = append(parts, AliMessagePart{
Image: part.ImageURL.URL,
})
}
}
messages = append(messages, AliMessage{
Content: parts,
Role: strings.ToLower(message.Role),
})
}
}
enableSearch := false
@@ -77,6 +89,7 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
Messages: messages,
},
Parameters: AliParameters{
ResultFormat: "message",
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
@@ -87,6 +100,7 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream {
@@ -134,10 +148,15 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
// 阿里云响应转OpenAI响应
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
// chatChoice := aliResponse.Output.ToChatCompletionChoices()
// jsonBody, _ := json.MarshalIndent(chatChoice, "", " ")
// fmt.Println("requestBody:", string(jsonBody))
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.Index = aliResponse.Output.Choices[0].Index
choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent()
// fmt.Println("choice.Delta.Content:", chatChoice[0].Message)
if aliResponse.Output.Choices[0].FinishReason != "null" {
finishReason := aliResponse.Output.Choices[0].FinishReason
choice.FinishReason = &finishReason
}
@@ -200,7 +219,8 @@ func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
// lastResponseText := ""
lastResponseText := ""
index := 0
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -216,9 +236,11 @@ func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
aliResponse.Model = model
aliResponse.Output.Choices[0].Index = index
index++
response := p.streamResponseAli2OpenAI(&aliResponse)
// response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
// lastResponseText = aliResponse.Output.Text
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Choices[0].Message.StringContent()
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())