diff --git a/api/core/types/chat.go b/api/core/types/chat.go index c11c78d2..175cd3ab 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -12,6 +12,9 @@ type ApiRequest struct { Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 ToolChoice string `json:"tool_choice,omitempty"` + + Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问 + Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问 } type Message struct { diff --git a/api/core/types/config.go b/api/core/types/config.go index a3b0c92c..bf1cc5e2 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -156,6 +156,7 @@ const Azure = Platform("Azure") const ChatGLM = Platform("ChatGLM") const Baidu = Platform("Baidu") const XunFei = Platform("XunFei") +const Ali = Platform("Ali") // UserChatConfig 用户的聊天配置 type UserChatConfig struct { diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 2ce9950c..27dc8726 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -267,6 +267,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio req.Temperature = h.App.ChatConfig.XunFei.Temperature req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens break + case types.Ali: + req.Input = map[string]interface{}{"messages": []map[string]string{{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}}} + req.Parameters = map[string]interface{}{} + break default: utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") utils.ReplyMessage(ws, ErrImg) @@ -340,7 +344,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) case types.XunFei: return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - + case types.Ali: + return h.sendQwenMessage(chatCtx, req, userVo, ctx, 
session, role, prompt, ws) } utils.ReplyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, @@ -434,6 +439,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf case types.Baidu: apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) break + case types.Ali: + apiURL = apiKey.ApiURL + req.Messages = nil + break default: if req.Model == "gpt-4-all" || strings.HasPrefix(req.Model, "gpt-4-gizmo-g-") { apiURL = "https://gpt.bemore.lol/v1/chat/completions" @@ -496,6 +505,11 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf request.RequestURI = "" case types.OpenAI: request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + break + case types.Ali: + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + request.Header.Set("X-DashScope-SSE", "enable") + break } return client.Do(request) } diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go new file mode 100644 index 00000000..9a6fd505 --- /dev/null +++ b/api/handler/chatimpl/qwen_handler.go @@ -0,0 +1,228 @@ +package chatimpl + +import ( + "bufio" + "chatplus/core/types" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "context" + "encoding/json" + "fmt" + "html/template" + "io" + "strings" + "time" + "unicode/utf8" +) + +type qwenResp struct { + Output struct { + FinishReason string `json:"finish_reason"` + Text string `json:"text"` + } `json:"output"` + Usage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + RequestID string `json:"request_id"` +} + +// 通义千问消息发送实现 +func (h *ChatHandler) sendQwenMessage( + chatCtx []interface{}, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := 
time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + return nil + } else if strings.Contains(err.Error(), "no available key") { + utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + return nil + } else { + logger.Error(err) + } + + utils.ReplyMessage(ws, ErrorMsg) + utils.ReplyMessage(ws, ErrImg) + return err + } else { + defer response.Body.Close() + } + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + scanner := bufio.NewScanner(response.Body) + + var content, lastText, newText string + + for scanner.Scan() { + line := scanner.Text() + if len(line) < 5 || strings.HasPrefix(line, "id:") || + strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { + continue + } + if strings.HasPrefix(line, "data:") { + content = line[5:] + } + // 处理代码换行 + if len(content) == 0 { + content = "\n" + } + + var resp qwenResp + err := utils.JsonDecode(content, &resp) + if err != nil { + logger.Error("error with parse data line: ", err) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + break + } + if len(contents) == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + + //通过比较 lastText(上一次的文本)和 currentText(当前的文本), + //提取出新添加的文本部分。然后只将这部分新文本发送到客户端。 + //每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。 + currentText := resp.Output.Text + if currentText != lastText { + // 找出新增的文本部分并发送 + newText = strings.Replace(currentText, lastText, "", 1) + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(newText), + }) // 发送新增的文本部分到客户端 + lastText = currentText // 更新 lastText + 
} + contents = append(contents, newText) + + if resp.Output.FinishReason == "stop" { + break + } + + } //end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + // 消息发送成功 + if len(contents) > 0 { + // 更新用户的对话次数 + h.subUserCalls(userVo, session) + + if message.Role == "" { + message.Role = "assistant" + } + message.Content = strings.Join(contents, "") + useMsg := types.Message{Role: "user", Content: prompt} + + // 更新上下文消息,如果是调用函数则不需要更新上下文 + if h.App.ChatConfig.EnableContext { + chatCtx = append(chatCtx, useMsg) // 提问消息 + chatCtx = append(chatCtx, message) // 回复消息 + h.App.ChatContexts.Put(session.ChatId, chatCtx) + } + + // 追加聊天记录 + if h.App.ChatConfig.EnableHistory { + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) + } + historyUserMsg := model.HistoryMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyToken, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyToken + getTotalTokens(req) + historyReplyMsg := model.HistoryMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", 
res.Error) + } + // 更新用户信息 + h.incUserTokenFee(userVo.Id, totalTokens) + } + + // 保存当前会话 + var chatItem model.ChatItem + res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + if res.Error != nil { + chatItem.ChatId = session.ChatId + chatItem.UserId = session.UserId + chatItem.RoleId = role.Id + chatItem.ModelId = session.Model.Id + if utf8.RuneCountInString(prompt) > 30 { + chatItem.Title = string([]rune(prompt)[:30]) + "..." + } else { + chatItem.Title = prompt + } + h.db.Create(&chatItem) + } + } + } else { + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("error with reading response: %v", err) + } + + var res struct { + Code string `json:"code"` + Msg string `json:"message"` + } + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("error with decode response: %v", err) + } + utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg) + } + + return nil +} diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index c283967b..5f44bd53 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -175,6 +175,11 @@ const platforms = ref([ value: "Azure", api_url: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15" }, + { + name: "【阿里】通义千问", + value: "Ali", + api_url: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + }, ]) const types = ref([ {name: "聊天", value: "chat"}, @@ -272,7 +277,7 @@ const changePlatform = () => { .opt-box { padding-bottom: 10px; - display flex; + display: flex; justify-content flex-end .el-icon { @@ -289,7 +294,7 @@ const changePlatform = () => { .el-form-item__content { .el-icon { - padding-left 10px; + padding-left: 10px; } } } diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue index c5611534..5b3b7d79 100644 --- a/web/src/views/admin/ChatModel.vue +++ b/web/src/views/admin/ChatModel.vue @@ -134,6 +134,8 @@ const
platforms = ref([ {name: "【清华智普】ChatGLM", value: "ChatGLM"}, {name: "【百度】文心一言", value: "Baidu"}, {name: "【微软】Azure", value: "Azure"}, + {name: "【阿里】通义千问", value: "Ali"}, + ]) // 获取数据 @@ -233,12 +235,12 @@ const remove = function (row) {