️ feat: channel support plugin settings (#89)

This commit is contained in:
Buer
2024-03-08 14:49:33 +08:00
committed by GitHub
parent 41134576f2
commit d8d880bf85
17 changed files with 243 additions and 65 deletions

View File

@@ -15,8 +15,6 @@ type aliStreamHandler struct {
lastStreamResponse string
}
const AliEnableSearchModelSuffix = "-internet"
func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
@@ -70,7 +68,7 @@ func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*
headers["X-DashScope-SSE"] = "enable"
}
aliRequest := convertFromChatOpenai(request)
aliRequest := p.convertFromChatOpenai(request)
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
if err != nil {
@@ -110,7 +108,7 @@ func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *ty
}
// 阿里云聊天请求体
func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest {
func (p *AliProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -141,24 +139,35 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{
Model: aliModel,
aliChatRequest := &AliChatRequest{
Model: request.Model,
Input: AliInput{
Messages: messages,
},
Parameters: AliParameters{
ResultFormat: "message",
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
p.pluginHandle(aliChatRequest)
return aliChatRequest
}
func (p *AliProvider) pluginHandle(request *AliChatRequest) {
if p.Channel.Plugin == nil {
return
}
plugin := p.Channel.Plugin.Data()
// 检测是否开启了 web_search 插件
if pWeb, ok := plugin["web_search"]; ok {
if enable, ok := pWeb["enable"].(bool); ok && enable {
request.Parameters.EnableSearch = true
}
}
}
// 转换为OpenAI聊天流式请求体
@@ -185,11 +194,11 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string,
return
}
h.convertToOpenaiStream(&aliResponse, dataChan, errChan)
h.convertToOpenaiStream(&aliResponse, dataChan)
}
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) {
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string) {
content := aliResponse.Output.Choices[0].Message.StringContent()
var choice types.ChatCompletionStreamChoice

View File

@@ -220,10 +220,10 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri
return
}
h.convertToOpenaiStream(miniResponse, dataChan, errChan)
h.convertToOpenaiStream(miniResponse, dataChan)
}
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) {
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{
ID: miniResponse.RequestID,
Object: "chat.completion.chunk",

View File

@@ -165,11 +165,11 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string,
return
}
h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan)
h.convertToOpenaiStream(&palmChatResponse, dataChan)
}
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) {
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string) {
var choice types.ChatCompletionStreamChoice
if len(palmChatResponse.Candidates) > 0 {
choice.Delta.Content = palmChatResponse.Candidates[0].Content

View File

@@ -180,11 +180,11 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri
return
}
h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan)
h.convertToOpenaiStream(&tencentChatResponse, dataChan)
}
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) {
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),

View File

@@ -256,7 +256,7 @@ func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, err
return
}
h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan)
h.convertToOpenaiStream(xunfeiChatResponse, dataChan)
if isFinished {
errChan <- io.EOF
@@ -264,7 +264,7 @@ func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, err
}
}
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) {
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string) {
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}

View File

@@ -73,7 +73,6 @@ func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = p.getZhipuToken()
return headers
}

View File

@@ -69,7 +69,7 @@ func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*h
// 获取请求头
headers := p.GetRequestHeaders()
zhipuRequest := convertFromChatOpenai(request)
zhipuRequest := p.convertFromChatOpenai(request)
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(zhipuRequest), p.Requester.WithHeader(headers))
@@ -94,7 +94,7 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty
ID: response.ID,
Object: "chat.completion",
Created: response.Created,
Model: response.Model,
Model: request.Model,
Choices: response.Choices,
Usage: response.Usage,
}
@@ -104,7 +104,7 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty
return
}
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
func (p *ZhipuProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
for i := range request.Messages {
request.Messages[i].Role = convertRole(request.Messages[i].Role)
}
@@ -153,7 +153,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
for _, function := range request.Functions {
zhipuRequest.Tools = append(zhipuRequest.Tools, ZhipuTool{
Type: "function",
Function: *function,
Function: function,
})
}
} else if request.Tools != nil {
@@ -161,14 +161,57 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
for _, tool := range request.Tools {
zhipuRequest.Tools = append(zhipuRequest.Tools, ZhipuTool{
Type: "function",
Function: tool.Function,
Function: &tool.Function,
})
}
}
p.pluginHandle(zhipuRequest)
return zhipuRequest
}
func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) {
if p.Channel.Plugin == nil {
return
}
plugin := p.Channel.Plugin.Data()
// 检测是否开启了 retrieval 插件
if pRetrieval, ok := plugin["retrieval"]; ok {
if knowledge_id, ok := pRetrieval["knowledge_id"].(string); ok && knowledge_id != "" {
retrieval := ZhipuTool{
Type: "retrieval",
Retrieval: &ZhipuRetrieval{
KnowledgeId: knowledge_id,
},
}
if prompt_template, ok := pRetrieval["prompt_template"].(string); ok && prompt_template != "" {
retrieval.Retrieval.PromptTemplate = prompt_template
}
request.Tools = append(request.Tools, retrieval)
// 如果开启了 retrieval 插件web_search 无效
return
}
}
// 检测是否开启了 web_search 插件
if pWeb, ok := plugin["web_search"]; ok {
if enable, ok := pWeb["enable"].(bool); ok && enable {
request.Tools = append(request.Tools, ZhipuTool{
Type: "web_search",
WebSearch: &ZhipuWebSearch{
Enable: true,
},
})
}
}
}
// 转换为OpenAI聊天流式请求体
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
@@ -198,10 +241,10 @@ func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string
return
}
h.convertToOpenaiStream(zhipuResponse, dataChan, errChan)
h.convertToOpenaiStream(zhipuResponse, dataChan)
}
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) {
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{
ID: zhipuResponse.ID,
Object: "chat.completion.chunk",

View File

@@ -37,10 +37,10 @@ func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*ty
return nil, errWithCode
}
return p.convertToImageOpenai(zhipuResponse, request)
return p.convertToImageOpenai(zhipuResponse)
}
func (p *ZhipuProvider) convertToImageOpenai(response *ZhipuImageGenerationResponse, request *types.ImageRequest) (openaiResponse *types.ImageResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *ZhipuProvider) convertToImageOpenai(response *ZhipuImageGenerationResponse) (openaiResponse *types.ImageResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
error := errorHandle(&response.Error)
if error != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{

View File

@@ -10,10 +10,16 @@ type ZhipuWebSearch struct {
SearchQuery string `json:"search_query,omitempty"`
}
type ZhipuRetrieval struct {
KnowledgeId string `json:"knowledge_id"`
PromptTemplate string `json:"prompt_template,omitempty"`
}
type ZhipuTool struct {
Type string `json:"type"`
Function types.ChatCompletionFunction `json:"function"`
WebSearch string `json:"web_search,omitempty"`
Type string `json:"type"`
Function *types.ChatCompletionFunction `json:"function,omitempty"`
WebSearch *ZhipuWebSearch `json:"web_search,omitempty"`
Retrieval *ZhipuRetrieval `json:"retrieval,omitempty"`
}
type ZhipuRequest struct {
Model string `json:"model"`