opt: make sure the Upscale and Variation task is assign to the same mj service with Image task

This commit is contained in:
RockYang 2023-12-18 16:34:33 +08:00
parent 245cd3ee1a
commit abf4f061c1
11 changed files with 68 additions and 11 deletions

View File

@ -16,6 +16,7 @@ const (
// MjTask MidJourney 任务 // MjTask MidJourney 任务
type MjTask struct { type MjTask struct {
Id int `json:"id"` Id int `json:"id"`
ChannelId string `json:"channel_id"`
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`

View File

@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
} }
type reqVo struct { type reqVo struct {
Src string `json:"src"` TaskId string `json:"task_id"`
Index int `json:"index"` Index int `json:"index"`
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"` MessageHash string `json:"message_hash"`
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
ReferenceId: data.MessageId,
UserId: userId,
TaskId: data.TaskId,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{ h.pool.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
Type: types.TaskUpscale, Type: types.TaskUpscale,
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
ChannelId: data.ChannelId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
ReferenceId: data.MessageId,
UserId: userId,
TaskId: data.TaskId,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.pool.PushTask(types.MjTask{ h.pool.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
Index: data.Index, Index: data.Index,
ChannelId: data.ChannelId,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })

View File

@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") { if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content // parse content
req := CBReq{ req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID, MessageId: m.ID,
ReferenceId: referenceId, ReferenceId: referenceId,
Prompt: extractPrompt(m.Content), Prompt: extractPrompt(m.Content),
@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
return return
} }
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
} }
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
} }
if strings.Contains(m.Content, "(Stopped)") { if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{ req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID, MessageId: m.ID,
ReferenceId: referenceId, ReferenceId: referenceId,
Prompt: extractPrompt(m.Content), Prompt: extractPrompt(m.Content),
@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
return return
} }
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
} }
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content) progress := extractProgress(content)
var status TaskStatus var status TaskStatus
if progress == 100 { if progress == 100 {
@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
Hash: extractHashFromFilename(attachment.Filename), Hash: extractHashFromFilename(attachment.Filename),
} }
req := CBReq{ req := CBReq{
ChannelId: channelId,
MessageId: messageId, MessageId: messageId,
ReferenceId: referenceId, ReferenceId: referenceId,
Image: image, Image: image,

View File

@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
if proxy != "" { if proxy != "" {
client.SetProxyURL(proxy) client.SetProxyURL(proxy)
} }
logger.Info(proxy) logger.Info(config)
return &Client{client: client, config: config} return &Client{client: client, config: config}
} }

View File

@ -54,6 +54,15 @@ func (s *Service) Run() {
err := s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
continue
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
s.taskQueue.RPush(task)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
time.Sleep(time.Second)
continue continue
} }
@ -74,7 +83,6 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err) logger.Error("绘画任务执行失败:", err)
// update the task progress // update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
atomic.AddInt32(&s.handledTaskNum, -1)
continue continue
} }
@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
return return
} }
res = s.db.Where("task_id = ?", split[0]).First(&job) tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
if data.ReferenceId != "" {
tx = tx.Where("reference_id = ?", data.ReferenceId)
}
res = tx.First(&job)
if res.Error != nil { if res.Error != nil {
logger.Warn("非法任务:", res.Error) logger.Warn("非法任务:", res.Error)
return return
} }
job.ChannelId = data.ChannelId
job.MessageId = data.MessageId job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId job.ReferenceId = data.ReferenceId
job.Progress = data.Progress job.Progress = data.Progress

View File

@ -24,6 +24,7 @@ type InteractionsResult struct {
} }
type CBReq struct { type CBReq struct {
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"` ReferenceId string `json:"reference_id"`
Image Image `json:"image"` Image Image `json:"image"`

View File

@ -7,6 +7,7 @@ type MidJourneyJob struct {
Type string Type string
UserId int UserId int
TaskId string TaskId string
ChannelId string
MessageId string MessageId string
ReferenceId string ReferenceId string
ImgURL string ImgURL string

View File

@ -6,6 +6,7 @@ type MidJourneyJob struct {
Id uint `json:"id"` Id uint `json:"id"`
Type string `json:"type"` Type string `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
ChannelId string `json:"channel_id"`
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"` ReferenceId string `json:"reference_id"`

View File

@ -0,0 +1,2 @@
ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`;
ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`;

View File

@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => {
} }
runningJobs.value = _jobs runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 5000) setTimeout(() => fetchRunningJobs(userId), 3000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => {
// //
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => { httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 5000) setTimeout(() => fetchFinishJobs(userId), 3000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
}) })
@ -660,6 +660,8 @@ const variation = (index, item) => {
const send = (url, index, item) => { const send = (url, index, item) => {
httpPost(url, { httpPost(url, {
index: index, index: index,
task_id: item.task_id,
channel_id: item.channel_id,
message_id: item.message_id, message_id: item.message_id,
message_hash: item.hash, message_hash: item.hash,
session_id: getSessionId(), session_id: getSessionId(),

View File

@ -598,7 +598,7 @@ onMounted(() => {
} }
runningJobs.value = _jobs runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 5000) setTimeout(() => fetchRunningJobs(userId), 3000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
}) })
@ -608,7 +608,7 @@ onMounted(() => {
const fetchFinishJobs = (userId) => { const fetchFinishJobs = (userId) => {
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => { httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 5000) setTimeout(() => fetchFinishJobs(userId), 3000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
}) })