From d02cb573fd31be9eed39dcdb90c9d5bb3a880ad1 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sat, 20 Apr 2024 21:30:55 +0800 Subject: [PATCH] DO NOT refresh finished jobs when job is running --- api/service/mj/pool.go | 11 +- api/service/mj/service.go | 11 +- api/service/sd/pool.go | 8 +- api/service/sd/service.go | 8 +- api/service/sd/types.go | 48 +-- web/src/router.js | 6 + web/src/views/Dalle.vue | 754 ++++++++++++++++++++++++++++++++++++++ web/src/views/ImageMj.vue | 16 +- web/src/views/ImageSd.vue | 16 +- web/src/views/MarkMap.vue | 2 +- 10 files changed, 816 insertions(+), 64 deletions(-) create mode 100644 web/src/views/Dalle.vue diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 7404021e..5fcb681d 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -4,6 +4,7 @@ import ( "chatplus/core/types" logger2 "chatplus/logger" "chatplus/service/oss" + "chatplus/service/sd" "chatplus/store" "chatplus/store/model" "fmt" @@ -69,16 +70,16 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa func (p *ServicePool) CheckTaskNotify() { go func() { for { - var userId uint - err := p.notifyQueue.LPop(&userId) + var message sd.NotifyMessage + err := p.notifyQueue.LPop(&message) if err != nil { continue } - cli := p.Clients.Get(userId) + cli := p.Clients.Get(uint(message.UserId)) if cli == nil { continue } - err = cli.Send([]byte("Task Updated")) + err = cli.Send([]byte(message.Message)) if err != nil { continue } @@ -127,7 +128,7 @@ func (p *ServicePool) DownloadImages() { if cli == nil { continue } - err = cli.Send([]byte("Task Updated")) + err = cli.Send([]byte(sd.Finished)) if err != nil { continue } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 2c782d41..0d5f0dea 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -3,6 +3,7 @@ package mj import ( "chatplus/core/types" "chatplus/service" + "chatplus/service/sd" "chatplus/store" "chatplus/store/model" "chatplus/utils" @@ -105,7 +106,7 @@ func (s *Service) Run() { // update the task progress s.db.Updates(&job) // 任务失败,通知前端 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) continue } logger.Infof("任务提交成功:%+v", res) @@ -147,7 +148,7 @@ func (s *Service) Notify(job model.MidJourneyJob) error { "progress": -1, "err_msg": task.FailReason, }) - s.notifyQueue.RPush(job.UserId) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) return fmt.Errorf("task failed: %v", task.FailReason) } @@ -166,7 +167,11 @@ func (s *Service) Notify(job model.MidJourneyJob) error { } // 通知前端更新任务进度 if oldProgress != job.Progress { - s.notifyQueue.RPush(job.UserId) + message := sd.Running + if job.Progress == 100 { + message = sd.Finished + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) } return nil } diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index 3033b548..50c39a6b 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -60,16 +60,16 @@ func (p *ServicePool) CheckTaskNotify() { go func() { logger.Info("Running Stable-Diffusion task notify checking ...") for { - var userId uint - err := p.notifyQueue.LPop(&userId) + var message NotifyMessage + err := p.notifyQueue.LPop(&message) if err != nil { continue } - client := p.Clients.Get(userId) + client := p.Clients.Get(uint(message.UserId)) if client == nil { continue } - err = client.Send([]byte("Task Updated")) + err = client.Send([]byte(message.Message)) if err != nil { continue } diff --git a/api/service/sd/service.go b/api/service/sd/service.go index f6d3b081..7ab82c95 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -81,7 +81,7 @@ func (s *Service) Run() { "err_msg": err.Error(), }) // 通知前端,任务失败 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) continue } } @@ -189,13 +189,13 @@ func (s *Service) Txt2Img(task types.SdTask) error { "progress": -1, "err_msg": err.Error(), }) - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) return err } // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) // 从 leveldb 中删除预览图片数据 _ = s.leveldb.Delete(task.Params.TaskId) return nil @@ -205,7 +205,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { if err == nil && resp.Progress > 0 { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) // 发送更新状态信号 - s.notifyQueue.RPush(task.UserId) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) // 保存预览图片数据 if resp.CurrentImage != "" { _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) diff --git a/api/service/sd/types.go b/api/service/sd/types.go index 56ebb5bd..eb172bcd 100644 --- a/api/service/sd/types.go +++ b/api/service/sd/types.go @@ -4,44 +4,14 @@ import logger2 "chatplus/logger" var logger = logger2.GetLogger() -type TaskInfo struct { - UserId uint `json:"user_id"` - SessionId string `json:"session_id"` - JobId int `json:"job_id"` - TaskId string `json:"task_id"` - Data []interface{} `json:"data"` - EventData interface{} `json:"event_data"` - FnIndex int `json:"fn_index"` - SessionHash string `json:"session_hash"` +type NotifyMessage struct { + UserId int `json:"user_id"` + JobId int `json:"job_id"` + Message string `json:"message"` } -type CBReq struct { - UserId uint - SessionId string - JobId int - TaskId string - ImageName string - ImageData string - Progress int - Seed int64 - Success bool - Message string -} - -var ParamKeys = map[string]int{ - "task_id": 0, - "prompt": 1, - "negative_prompt": 2, - "steps": 4, - "sampler": 5, - "face_fix": 7, // 面部修复 - "cfg_scale": 8, - "seed": 27, - "height": 10, - "width": 9, - "hd_fix": 11, - "hd_redraw_rate": 12, //高清修复重绘幅度 - "hd_scale": 13, // 高清修复放大倍数 - "hd_scale_alg": 14, // 高清修复放大算法 - "hd_sample_num": 15, // 高清修复采样次数 -} +const ( + Running = "RUNNING" + Finished = "FINISH" + Failed = "FAIL" +) diff --git a/web/src/router.js b/web/src/router.js index b55655f8..2454a974 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -68,6 +68,12 @@ const routes = [ meta: {title: '思维导图'}, component: () => import('@/views/MarkMap.vue'), }, + { + name: 'dalle', + path: '/dalle', + meta: {title: 'DALLE-3'}, + component: () => import('@/views/Dalle.vue'), + }, ] }, { diff --git a/web/src/views/Dalle.vue b/web/src/views/Dalle.vue new file mode 100644 index 00000000..c81a9cc9 --- /dev/null +++ b/web/src/views/Dalle.vue @@ -0,0 +1,754 @@ + + + + + diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 4f3f0534..e8560366 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -725,10 +725,18 @@ const connect = () => { _socket.addEventListener('message', event => { if (event.data instanceof Blob) { - fetchRunningJobs() - isOver.value = false - page.value = 1 - fetchFinishJobs(page.value) + const reader = new FileReader(); + reader.readAsText(event.data, "UTF-8") + reader.onload = () => { + const message = String(reader.result) + if (message === "FINISH") { + page.value = 1 + fetchFinishJobs(page.value) + isOver.value = false + } else { + fetchRunningJobs() + } + } } }); diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index c81a9cc9..78484b30 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -568,10 +568,18 @@ const connect = () => { _socket.addEventListener('message', event => { if (event.data instanceof Blob) { - fetchRunningJobs() - isOver.value = false - page.value = 1 - fetchFinishJobs(page.value) + const reader = new FileReader(); + reader.readAsText(event.data, "UTF-8") + reader.onload = () => { + const message = String(reader.result) + if (message === "FINISH") { + page.value = 1 + fetchFinishJobs(page.value) + isOver.value = false + } else { + fetchRunningJobs() + } + } } }); diff --git a/web/src/views/MarkMap.vue b/web/src/views/MarkMap.vue index 76c7706d..3d7ecc55 100644 --- a/web/src/views/MarkMap.vue +++ b/web/src/views/MarkMap.vue @@ -94,7 +94,7 @@ import {loadCSS, loadJS} from 'markmap-common'; import {Transformer} from 'markmap-lib'; import {checkSession} from "@/action/session"; import {httpGet} from "@/utils/http"; -import {ElMessage, ElMessageBox} from "element-plus"; +import {ElMessage} from "element-plus"; const leftBoxHeight = ref(window.innerHeight - 105) const rightBoxHeight = ref(window.innerHeight - 85)