From 7ad41927aa93970de2eda9373897eb0fbee1121e Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 15 Apr 2024 17:23:59 +0800 Subject: [PATCH] feat: markmap function is ready --- CHANGELOG.md | 4 + api/core/app_server.go | 2 +- api/core/types/web.go | 2 +- api/handler/chatimpl/chat_handler.go | 3 +- api/handler/markmap_handler.go | 180 +++++++++++++++++++++++++-- api/main.go | 2 +- api/service/mj/plus_client.go | 7 +- database/update-v4.0.3.sql | 3 +- web/src/assets/css/mark-map.styl | 33 +++++ web/src/views/ImageMj.vue | 114 ++++------------- web/src/views/ImageSd.vue | 36 +++--- web/src/views/MarkMap.vue | 174 ++++++++++++++++++++++---- 12 files changed, 418 insertions(+), 142 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b134612..89175d99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ * Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题 * 功能新增:支持上传图片和视觉模型 * 功能优化:优化聊天页面的复制代码按钮样式乱码 +* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图 +* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4 +* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。 +* 功能新增:支持管理后台 Logo 修改 ## 4.0.2 diff --git a/api/core/app_server.go b/api/core/app_server.go index 69645add..5c9d2ad6 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -218,7 +218,7 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/config/get" || c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/menu/list" || - c.Request.URL.Path == "/api/markMap/model" || + c.Request.URL.Path == "/api/markMap/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || diff --git a/api/core/types/web.go b/api/core/types/web.go index 601612fa..041a9859 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -21,7 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") - WsMjImg = WsMsgType("mj") + WsErr = WsMsgType("error") ) type BizCode int diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 08785752..dbb9f682 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -525,7 +525,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi request = request.WithContext(ctx) request.Header.Set("Content-Type", "application/json") - var proxyURL string if len(apiKey.ProxyURL) > 5 { // 使用代理 proxy, _ := url.Parse(apiKey.ProxyURL) client = &http.Client{ @@ -536,7 +535,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi } else { client = http.DefaultClient } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model) + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) switch session.Model.Platform { case types.Azure: request.Header.Set("api-key", apiKey.Value) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index e4e57620..794eb284 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -1,26 +1,35 @@ package handler import ( + "bufio" + "bytes" "chatplus/core" "chatplus/core/types" + "chatplus/store/model" "chatplus/utils" - "github.com/gorilla/websocket" - "net/http" - + "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" ) // MarkMapHandler 生成思维导图 type MarkMapHandler struct { BaseHandler - clients *types.LMap[uint, *types.WsClient] + clients *types.LMap[int, *types.WsClient] } func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { return &MarkMapHandler{ BaseHandler: BaseHandler{App: app, DB: db}, - clients: types.NewLMap[uint, *types.WsClient](), + clients: types.NewLMap[int, *types.WsClient](), } } @@ -32,9 +41,13 @@ func (h *MarkMapHandler) Client(c *gin.Context) { } modelId := h.GetInt(c, "model_id", 0) - userId := h.GetLoginUserId(c) + userId := h.GetInt(c, "user_id", 0) logger.Info(modelId) + client := types.NewWsClient(ws) + if cli := h.clients.Get(userId); cli != nil { + cli.Close() + } // 保存会话连接 h.clients.Put(userId, client) @@ -55,12 +68,165 @@ func (h *MarkMapHandler) Client(c *gin.Context) { // 心跳消息 if message.Type == "heartbeat" { - logger.Debug("收到 Chat 心跳消息:", message.Content) + logger.Debug("收到 MarkMap 心跳消息:", message.Content) + continue + } + // change model + if message.Type == "model_id" { + modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) continue } logger.Info("Receive a message: ", message.Content) + err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) + if err != nil { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) + } } }() } + +func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { + var user model.User + res := h.DB.Model(&model.User{}).First(&user, userId) + if res.Error != nil { + return fmt.Errorf("error with query user info: %v", res.Error) + } + var chatModel model.ChatModel + res = h.DB.Where("id", modelId).First(&chatModel) + if res.Error != nil { + return fmt.Errorf("error with query chat model: %v", res.Error) + } + + if user.Status == false { + return errors.New("当前用户被禁用") + } + + if user.Power < chatModel.Power { + return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) + } + + messages := make([]interface{}, 0) + messages = append(messages, types.Message{Role: "system", Content: "你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。只输出 Markdown 内容,不要输出任何解释性的语句。"}) + messages = append(messages, types.Message{Role: "user", Content: prompt}) + var req = types.ApiRequest{ + Model: chatModel.Value, + Stream: true, + Messages: messages, + } + + var apiKey model.ApiKey + response, err := h.doRequest(req, chatModel, &apiKey) + if err != nil { + return fmt.Errorf("请求 OpenAI API 失败: %s", err) + } + + defer response.Body.Close() + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + // 循环读取 Chunk 消息 + var message = types.Message{} + scanner := bufio.NewScanner(response.Body) + var isNew = true + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 + return fmt.Errorf("error with decode data: %v", err) + } + + // 初始化 role + if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { + message.Role = responseBody.Choices[0].Delta.Role + continue + } else if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { + if isNew { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(client, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } + } // end for + + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + + } else { + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("读取响应失败: %v", err) + } + var res types.ApiError + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("解析响应失败: %v", err) + } + + // OpenAI API 调用异常处理 + if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { + // remove key + h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) + return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") + } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { + return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") + } else { + return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message) + } + } + + return nil +} + +func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + var res *gorm.DB + if chatModel.KeyId > 0 { + res = h.DB.Where("id", chatModel.KeyId).Find(apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + } + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + apiURL := apiKey.ApiURL + // 更新 API KEY 的最后使用时间 + h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + // 创建 HttpClient 请求对象 + var client *http.Client + requestBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + if len(apiKey.ProxyURL) > 5 { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } else { + client = http.DefaultClient + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + return client.Do(request) +} diff --git a/api/main.go b/api/main.go index 4cc2d5c5..586b7f70 100644 --- a/api/main.go +++ b/api/main.go @@ -439,7 +439,7 @@ func main() { fx.Provide(handler.NewMarkMapHandler), fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { group := s.Engine.Group("/api/markMap/") - group.GET("model", h.GetModel) + group.Any("client", h.Client) }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { go func() { diff --git a/api/service/mj/plus_client.go b/api/service/mj/plus_client.go index 52846208..822d4b91 100644 --- a/api/service/mj/plus_client.go +++ b/api/service/mj/plus_client.go @@ -73,6 +73,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { // Blend 融图 func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) body := ImageReq{ BotType: "MID_JOURNEY", Dimensions: "SQUARE", @@ -163,7 +164,8 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := c.client.R(). @@ -189,7 +191,8 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.Config.Mode, c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) var res ImageRes var errRes ErrRes r, err := req.C().R(). diff --git a/database/update-v4.0.3.sql b/database/update-v4.0.3.sql index 219c4187..80250e73 100644 --- a/database/update-v4.0.3.sql +++ b/database/update-v4.0.3.sql @@ -1,2 +1,3 @@ ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`; -ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`; \ No newline at end of file +ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`; +INSERT INTO `chatgpt_plus`.`chatgpt_menus`(`id`, `name`, `icon`, `url`, `sort_num`, `enabled`) VALUES (12, '思维导图', '/images/menu/xmind.png', '/xmind', 3, 1); \ No newline at end of file diff --git a/web/src/assets/css/mark-map.styl b/web/src/assets/css/mark-map.styl index 943e33ae..d6d11f2a 100644 --- a/web/src/assets/css/mark-map.styl +++ b/web/src/assets/css/mark-map.styl @@ -66,10 +66,43 @@ .right-box { width 100% + h2 { color #ffffff } + .markdown { + color #ffffff + display flex + justify-content center + align-items center + + h1 { + color: #47fff1; + } + + h2 { + color: #ffcc00; + } + + ul { + list-style-type: disc; + margin-left: 20px; + + li { + line-height 1.5 + } + } + + strong { + font-weight: bold; + } + + em { + font-style: italic; + } + } + .body { display flex justify-content center diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 70bb4839..12952bf7 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -525,42 +525,10 @@
    -
  • - - U1 - -
  • -
  • - - U2 - -
  • -
  • - - U3 - -
  • -
  • - - U4 - -
  • +
  • U1
  • +
  • U2
  • +
  • U3
  • +
  • U4
  • @@ -586,42 +554,10 @@
@@ -797,23 +733,25 @@ const connect = () => { }); _socket.addEventListener('close', () => { - ElMessageBox.confirm( - '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', - '提示', - { - dangerouslyUseHTMLString: true, - confirmButtonText: '重新连接', - cancelButtonText: '关闭', - type: 'warning', - } - ).then(() => { - connect() - }).catch(() => { - ElMessage({ - type: 'info', - message: '连接已关闭', + if (socket.value !== null) { + ElMessageBox.confirm( + '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', + '提示', + { + dangerouslyUseHTMLString: true, + confirmButtonText: '重新连接', + cancelButtonText: '关闭', + type: 'warning', + } + ).then(() => { + connect() + }).catch(() => { + ElMessage({ + type: 'info', + message: '连接已关闭', + }) }) - }) + } }); } diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 73bbc807..2bbc3fa0 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -576,24 +576,26 @@ const connect = () => { }); _socket.addEventListener('close', () => { - ElMessageBox.confirm( - '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', - '提示', - { - dangerouslyUseHTMLString: true, - confirmButtonText: '重新连接', - cancelButtonText: '关闭', - type: 'warning', - } - ).then(() => { - connect() - }).catch(() => { - ElMessage({ - type: 'info', - message: '连接已关闭', + if (socket.value !== null) { + ElMessageBox.confirm( + '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', + '提示', + { + dangerouslyUseHTMLString: true, + confirmButtonText: '重新连接', + cancelButtonText: '关闭', + type: 'warning', + } + ).then(() => { + connect() + }).catch(() => { + ElMessage({ + type: 'info', + message: '连接已关闭', + }) }) - }) - }); + } + }) } const clipboard = ref(null) diff --git a/web/src/views/MarkMap.vue b/web/src/views/MarkMap.vue index 32aedeb8..53e5f8f8 100644 --- a/web/src/views/MarkMap.vue +++ b/web/src/views/MarkMap.vue @@ -23,7 +23,7 @@ 请选择生成思维导图的AI模型
- +
- 当前可用算力:{{ power }} + 当前可用算力:{{ loginUser.power }}
- 智能生成思维导图 + + 智能生成思维导图 +
@@ -52,7 +54,7 @@
- 直接生成(免费) + 直接生成(免费)
@@ -69,7 +71,10 @@

思维导图

-
+
+
+
+
@@ -83,13 +88,13 @@