From 2113508b6d22c88f9a62e5ff217ba6252fc0ce2a Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 24 Jan 2024 09:33:04 +0800 Subject: [PATCH] feat: add websocket heartbeat message for mj page --- api/handler/sd_handler.go | 23 +++++++++++++++++++++++ api/service/sd/pool.go | 17 +++++++++++------ api/service/sd/service.go | 8 ++++++-- web/src/views/ImageMj.vue | 9 +++++++++ 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 171a57fd..37a0e6e4 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -11,6 +11,8 @@ import ( "chatplus/utils/resp" "encoding/base64" "fmt" + "github.com/gorilla/websocket" + "net/http" "time" "github.com/gin-gonic/gin" @@ -36,6 +38,27 @@ func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, man return &h } +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *SdJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.pool.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + func (h *SdJobHandler) checkLimits(c *gin.Context) bool { user, err := utils.GetLoginUser(c, h.db) if err != nil { diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index b31aedf9..131b66c1 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -11,13 +11,16 @@ import ( ) type ServicePool struct { - services []*Service - taskQueue *store.RedisQueue + services []*Service + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client } func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { services := make([]*Service, 0) - queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) + taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) + notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) // create mj client and service for k, config := range appConfig.SdConfigs { if config.Enabled == false { @@ -26,7 +29,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa // create sd service name := fmt.Sprintf("StableDifffusion Service-%d", k) - service := NewService(name, 1, 300, config, queue, db, manager) + service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager) // run sd service go func() { service.Run() @@ -36,8 +39,10 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa } return &ServicePool{ - taskQueue: queue, - services: services, + taskQueue: taskQueue, + notifyQueue: notifyQueue, + services: services, + Clients: types.NewLMap[uint, *types.WsClient](), } } diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 1741028f..c090583f 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -24,6 +24,7 @@ type Service struct { httpClient *req.Client config types.StableDiffusionConfig taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue db *gorm.DB uploadManager *oss.UploaderManager name string // service name @@ -33,12 +34,13 @@ type Service struct { taskTimeout int64 } -func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { +func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { return &Service{ name: name, config: config, httpClient: req.C(), - taskQueue: queue, + taskQueue: taskQueue, + notifyQueue: notifyQueue, db: db, uploadManager: manager, taskTimeout: timeout, @@ -73,6 +75,8 @@ func (s *Service) Run() { s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) // release task num atomic.AddInt32(&s.handledTaskNum, -1) + // 通知前端,任务失败 + s.notifyQueue.RPush(task.UserId) continue } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 5e3b36e8..6a909aab 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -563,6 +563,7 @@ const translatePrompt = () => { }) } +const heartbeatHandle = ref(null) const connect = () => { let host = process.env.VUE_APP_WS_HOST if (host === '') { @@ -575,6 +576,14 @@ const connect = () => { const _socket = new WebSocket(host + `/api/mj/client?user_id=${userId.value}`); _socket.addEventListener('open', () => { socket.value = _socket; + + // 发送心跳消息 + clearInterval(heartbeatHandle.value) + heartbeatHandle.value = setInterval(() => { + if (socket.value !== null) { + socket.value.send(JSON.stringify({type: "heartbeat", content: "ping"})) + } + }, 5000); }); _socket.addEventListener('message', event => {