From abf4f061c11048534ae978f3b93f814f735a3070 Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 18 Dec 2023 16:34:33 +0800 Subject: [PATCH] opt: make sure the Upscale and Variation task is assign to the same mj service with Image task --- api/core/types/task.go | 1 + api/handler/mj_handler.go | 34 +++++++++++++++++++++++++++++++++- api/service/mj/bot.go | 9 ++++++--- api/service/mj/client.go | 2 +- api/service/mj/service.go | 18 ++++++++++++++++-- api/service/mj/types.go | 1 + api/store/model/mj_job.go | 1 + api/store/vo/mj_job.go | 1 + database/update-v3.2.3.sql | 2 ++ web/src/views/ImageMj.vue | 6 ++++-- web/src/views/ImageSd.vue | 4 ++-- 11 files changed, 68 insertions(+), 11 deletions(-) create mode 100644 database/update-v3.2.3.sql diff --git a/api/core/types/task.go b/api/core/types/task.go index 8173c26a..f8a3b48b 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -16,6 +16,7 @@ const ( // MjTask MidJourney 任务 type MjTask struct { Id int `json:"id"` + ChannelId string `json:"channel_id"` SessionId string `json:"session_id"` Type TaskType `json:"type"` UserId int `json:"user_id"` diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 2adbf7df..4c43c2fe 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { } type reqVo struct { - Src string `json:"src"` + TaskId string `json:"task_id"` Index int `json:"index"` + ChannelId string `json:"channel_id"` MessageId string `json:"message_id"` MessageHash string `json:"message_hash"` SessionId string `json:"session_id"` @@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.MidJourneyJob{ + Type: types.TaskUpscale.String(), + ReferenceId: data.MessageId, + UserId: userId, + TaskId: data.TaskId, + Progress: 0, + Prompt: data.Prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error != nil { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + h.pool.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, Type: types.TaskUpscale, Prompt: data.Prompt, UserId: userId, + ChannelId: data.ChannelId, Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, @@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + + job := model.MidJourneyJob{ + Type: types.TaskVariation.String(), + ReferenceId: data.MessageId, + UserId: userId, + TaskId: data.TaskId, + Progress: 0, + Prompt: data.Prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error != nil { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + h.pool.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, @@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { Prompt: data.Prompt, UserId: userId, Index: data.Index, + ChannelId: data.ChannelId, MessageId: data.MessageId, MessageHash: data.MessageHash, }) diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go index e4412b83..2d78b65c 100644 --- a/api/service/mj/bot.go +++ b/api/service/mj/bot.go @@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") { // parse content req := CBReq{ + ChannelId: m.ChannelID, MessageId: m.ID, ReferenceId: referenceId, Prompt: extractPrompt(m.Content), @@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { return } - b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) + b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) } func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { @@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { } if strings.Contains(m.Content, "(Stopped)") { req := CBReq{ + ChannelId: m.ChannelID, MessageId: m.ID, ReferenceId: referenceId, Prompt: extractPrompt(m.Content), @@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { return } - b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) + b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) } -func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { +func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { progress := extractProgress(content) var status TaskStatus if progress == 100 { @@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string Hash: extractHashFromFilename(attachment.Filename), } req := CBReq{ + ChannelId: channelId, MessageId: messageId, ReferenceId: referenceId, Image: image, diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 6078b708..7eb15bb4 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client { if proxy != "" { client.SetProxyURL(proxy) } - logger.Info(proxy) + logger.Info(config) return &Client{client: client, config: config} } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 754966e5..2032f04f 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -54,6 +54,15 @@ func (s *Service) Run() { err := s.taskQueue.LPop(&task) if err != nil { logger.Errorf("taking task with error: %v", err) + s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + continue + } + + // if it's reference message, check if it's this channel's message + if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId { + s.taskQueue.RPush(task) + s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + time.Sleep(time.Second) continue } @@ -74,7 +83,6 @@ func (s *Service) Run() { logger.Error("绘画任务执行失败:", err) // update the task progress s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) - atomic.AddInt32(&s.handledTaskNum, -1) continue } @@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) { return } - res = s.db.Where("task_id = ?", split[0]).First(&job) + tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC") + if data.ReferenceId != "" { + tx = tx.Where("reference_id = ?", data.ReferenceId) + } + res = tx.First(&job) if res.Error != nil { logger.Warn("非法任务:", res.Error) return } + + job.ChannelId = data.ChannelId job.MessageId = data.MessageId job.ReferenceId = data.ReferenceId job.Progress = data.Progress diff --git a/api/service/mj/types.go b/api/service/mj/types.go index 7de8c0c1..ec367210 100644 --- a/api/service/mj/types.go +++ b/api/service/mj/types.go @@ -24,6 +24,7 @@ type InteractionsResult struct { } type CBReq struct { + ChannelId string `json:"channel_id"` MessageId string `json:"message_id"` ReferenceId string `json:"reference_id"` Image Image `json:"image"` diff --git a/api/store/model/mj_job.go b/api/store/model/mj_job.go index 3ffd99d8..488c3b20 100644 --- a/api/store/model/mj_job.go +++ b/api/store/model/mj_job.go @@ -7,6 +7,7 @@ type MidJourneyJob struct { Type string UserId int TaskId string + ChannelId string MessageId string ReferenceId string ImgURL string diff --git a/api/store/vo/mj_job.go b/api/store/vo/mj_job.go index cbb9bbc6..bfc236ce 100644 --- a/api/store/vo/mj_job.go +++ b/api/store/vo/mj_job.go @@ -6,6 +6,7 @@ type MidJourneyJob struct { Id uint `json:"id"` Type string `json:"type"` UserId int `json:"user_id"` + ChannelId string `json:"channel_id"` TaskId string `json:"task_id"` MessageId string `json:"message_id"` ReferenceId string `json:"reference_id"` diff --git a/database/update-v3.2.3.sql b/database/update-v3.2.3.sql new file mode 100644 index 00000000..14ada2ef --- /dev/null +++ b/database/update-v3.2.3.sql @@ -0,0 +1,2 @@ +ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`; +ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`; \ No newline at end of file diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index f68ca358..64533adb 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => { } runningJobs.value = _jobs - setTimeout(() => fetchRunningJobs(userId), 5000) + setTimeout(() => fetchRunningJobs(userId), 3000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) @@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => { // 获取已完成的任务 httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => { finishedJobs.value = res.data - setTimeout(() => fetchFinishJobs(userId), 5000) + setTimeout(() => fetchFinishJobs(userId), 3000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) }) @@ -660,6 +660,8 @@ const variation = (index, item) => { const send = (url, index, item) => { httpPost(url, { index: index, + task_id: item.task_id, + channel_id: item.channel_id, message_id: item.message_id, message_hash: item.hash, session_id: getSessionId(), diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 761f01bb..f3f689a6 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -598,7 +598,7 @@ onMounted(() => { } runningJobs.value = _jobs - setTimeout(() => fetchRunningJobs(userId), 5000) + setTimeout(() => fetchRunningJobs(userId), 3000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) }) @@ -608,7 +608,7 @@ onMounted(() => { const fetchFinishJobs = (userId) => { httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => { finishedJobs.value = res.data - setTimeout(() => fetchFinishJobs(userId), 5000) + setTimeout(() => fetchFinishJobs(userId), 3000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) })