opt: make sure the Upscale and Variation task is assign to the same mj service with Image task

This commit is contained in:
RockYang
2023-12-18 16:34:33 +08:00
parent 245cd3ee1a
commit abf4f061c1
11 changed files with 68 additions and 11 deletions

View File

@@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
// parse content
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
@@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
@@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
}
if strings.Contains(m.Content, "(Stopped)") {
req := CBReq{
ChannelId: m.ChannelID,
MessageId: m.ID,
ReferenceId: referenceId,
Prompt: extractPrompt(m.Content),
@@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
return
}
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
}
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
progress := extractProgress(content)
var status TaskStatus
if progress == 100 {
@@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
Hash: extractHashFromFilename(attachment.Filename),
}
req := CBReq{
ChannelId: channelId,
MessageId: messageId,
ReferenceId: referenceId,
Image: image,

View File

@@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
if proxy != "" {
client.SetProxyURL(proxy)
}
logger.Info(proxy)
logger.Info(config)
return &Client{client: client, config: config}
}

View File

@@ -54,6 +54,15 @@ func (s *Service) Run() {
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
continue
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
s.taskQueue.RPush(task)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
time.Sleep(time.Second)
continue
}
@@ -74,7 +83,6 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err)
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
atomic.AddInt32(&s.handledTaskNum, -1)
continue
}
@@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
return
}
res = s.db.Where("task_id = ?", split[0]).First(&job)
tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
if data.ReferenceId != "" {
tx = tx.Where("reference_id = ?", data.ReferenceId)
}
res = tx.First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.ChannelId = data.ChannelId
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress

View File

@@ -24,6 +24,7 @@ type InteractionsResult struct {
}
type CBReq struct {
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`