feat: add websocket heartbeat message for mj page

This commit is contained in:
RockYang 2024-01-24 09:33:04 +08:00
parent 7fe4212684
commit 2113508b6d
4 changed files with 49 additions and 8 deletions

View File

@ -11,6 +11,8 @@ import (
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gorilla/websocket"
"net/http"
"time"
"github.com/gin-gonic/gin"
@ -36,6 +38,27 @@ func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, man
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {

View File

@ -13,11 +13,14 @@ import (
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
if config.Enabled == false {
@ -26,7 +29,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 1, 300, config, queue, db, manager)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
// run sd service
go func() {
service.Run()
@ -36,8 +39,10 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
}
return &ServicePool{
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}

View File

@ -24,6 +24,7 @@ type Service struct {
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
@ -33,12 +34,13 @@ type Service struct {
taskTimeout int64
}
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
return &Service{
name: name,
config: config,
httpClient: req.C(),
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
uploadManager: manager,
taskTimeout: timeout,
@ -73,6 +75,8 @@ func (s *Service) Run() {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
continue
}

View File

@ -563,6 +563,7 @@ const translatePrompt = () => {
})
}
const heartbeatHandle = ref(null)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
@ -575,6 +576,14 @@ const connect = () => {
const _socket = new WebSocket(host + `/api/mj/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
//
clearInterval(heartbeatHandle.value)
heartbeatHandle.value = setInterval(() => {
if (socket.value !== null) {
socket.value.send(JSON.stringify({type: "heartbeat", content: "ping"}))
}
}, 5000);
});
_socket.addEventListener('message', event => {