From b5f6eaf159b22d2989f4148c6da753e5221c40fb Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 15 Apr 2024 09:34:20 +0800 Subject: [PATCH] opt: close the old connection for mj and sd clients --- api/handler/markmap_handler.go | 66 +++++++++++++++++++++------------- api/handler/mj_handler.go | 4 +++ api/handler/sd_handler.go | 4 +++ web/src/views/ImageMj.vue | 18 ++++++++-- web/src/views/ImageSd.vue | 18 ++++++++-- 5 files changed, 82 insertions(+), 28 deletions(-) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index c4d749fd..e4e57620 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -3,10 +3,9 @@ package handler import ( "chatplus/core" "chatplus/core/types" - "chatplus/store/model" - "chatplus/store/vo" "chatplus/utils" - "chatplus/utils/resp" + "github.com/gorilla/websocket" + "net/http" "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -15,34 +14,53 @@ import ( // MarkMapHandler 生成思维导图 type MarkMapHandler struct { BaseHandler + clients *types.LMap[uint, *types.WsClient] } func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { - return &MarkMapHandler{BaseHandler: BaseHandler{App: app, DB: db}} + return &MarkMapHandler{ + BaseHandler: BaseHandler{App: app, DB: db}, + clients: types.NewLMap[uint, *types.WsClient](), + } } -// GetModel get the chat model for generating Markdown text -func (h *MarkMapHandler) GetModel(c *gin.Context) { - modelId := h.App.SysConfig.XMindModelId - session := h.DB.Session(&gorm.Session{}).Where("enabled", true) - if modelId > 0 { - session = session.Where("id", modelId) - } else { - session = session.Where("platform", types.OpenAI) - } - var chatModel model.ChatModel - res := session.First(&chatModel) - if res.Error != nil { - resp.ERROR(c, "No available AI model") - return - } - - var modelVo vo.ChatModel - err := utils.CopyObject(chatModel, &modelVo) +func (h *MarkMapHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) if err != nil { - resp.ERROR(c, "error with copy object: "+err.Error()) + logger.Error(err) return } - resp.SUCCESS(c, modelVo) + modelId := h.GetInt(c, "model_id", 0) + userId := h.GetLoginUserId(c) + logger.Info(modelId) + client := types.NewWsClient(ws) + + // 保存会话连接 + h.clients.Put(userId, client) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.clients.Delete(userId) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 Chat 心跳消息:", message.Content) + continue + } + + logger.Info("Receive a message: ", message.Content) + + } + }() } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index e0e0f020..48511ecd 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -78,6 +78,10 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) + // close the existed connections + if cli := h.pool.Clients.Get(uint(userId)); cli != nil { + cli.Close() + } h.pool.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index b9c3625e..280799da 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -61,6 +61,10 @@ func (h *SdJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) + // close the existed connections + if cli := h.pool.Clients.Get(uint(userId)); cli != nil { + cli.Close() + } h.pool.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index d55c874b..70bb4839 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -797,9 +797,23 @@ const connect = () => { }); _socket.addEventListener('close', () => { - if (socket.value !== null) { + ElMessageBox.confirm( + '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', + '提示', + { + dangerouslyUseHTMLString: true, + confirmButtonText: '重新连接', + cancelButtonText: '关闭', + type: 'warning', + } + ).then(() => { connect() - } + }).catch(() => { + ElMessage({ + type: 'info', + message: '连接已关闭', + }) + }) }); } diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index b34d0747..73bbc807 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -576,9 +576,23 @@ const connect = () => { }); _socket.addEventListener('close', () => { - if (socket.value !== null) { + ElMessageBox.confirm( + '检测到您已经在其他客户端创建了新的连接,当前连接将被关闭!', + '提示', + { + dangerouslyUseHTMLString: true, + confirmButtonText: '重新连接', + cancelButtonText: '关闭', + type: 'warning', + } + ).then(() => { connect() - } + }).catch(() => { + ElMessage({ + type: 'info', + message: '连接已关闭', + }) + }) }); }