mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	opt: make sure the Upscale and Variation task is assign to the same mj service with Image task
This commit is contained in:
		@@ -16,6 +16,7 @@ const (
 | 
				
			|||||||
// MjTask MidJourney 任务
 | 
					// MjTask MidJourney 任务
 | 
				
			||||||
type MjTask struct {
 | 
					type MjTask struct {
 | 
				
			||||||
	Id          int      `json:"id"`
 | 
						Id          int      `json:"id"`
 | 
				
			||||||
 | 
						ChannelId   string   `json:"channel_id"`
 | 
				
			||||||
	SessionId   string   `json:"session_id"`
 | 
						SessionId   string   `json:"session_id"`
 | 
				
			||||||
	Type        TaskType `json:"type"`
 | 
						Type        TaskType `json:"type"`
 | 
				
			||||||
	UserId      int      `json:"user_id"`
 | 
						UserId      int      `json:"user_id"`
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type reqVo struct {
 | 
					type reqVo struct {
 | 
				
			||||||
	Src         string `json:"src"`
 | 
						TaskId      string `json:"task_id"`
 | 
				
			||||||
	Index       int    `json:"index"`
 | 
						Index       int    `json:"index"`
 | 
				
			||||||
 | 
						ChannelId   string `json:"channel_id"`
 | 
				
			||||||
	MessageId   string `json:"message_id"`
 | 
						MessageId   string `json:"message_id"`
 | 
				
			||||||
	MessageHash string `json:"message_hash"`
 | 
						MessageHash string `json:"message_hash"`
 | 
				
			||||||
	SessionId   string `json:"session_id"`
 | 
						SessionId   string `json:"session_id"`
 | 
				
			||||||
@@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
						idValue, _ := c.Get(types.LoginUserID)
 | 
				
			||||||
	jobId := 0
 | 
						jobId := 0
 | 
				
			||||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
						userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
				
			||||||
 | 
						job := model.MidJourneyJob{
 | 
				
			||||||
 | 
							Type:        types.TaskUpscale.String(),
 | 
				
			||||||
 | 
							ReferenceId: data.MessageId,
 | 
				
			||||||
 | 
							UserId:      userId,
 | 
				
			||||||
 | 
							TaskId:      data.TaskId,
 | 
				
			||||||
 | 
							Progress:    0,
 | 
				
			||||||
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
 | 
							CreatedAt:   time.Now(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if res := h.db.Create(&job); res.Error != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h.pool.PushTask(types.MjTask{
 | 
						h.pool.PushTask(types.MjTask{
 | 
				
			||||||
		Id:          jobId,
 | 
							Id:          jobId,
 | 
				
			||||||
		SessionId:   data.SessionId,
 | 
							SessionId:   data.SessionId,
 | 
				
			||||||
		Type:        types.TaskUpscale,
 | 
							Type:        types.TaskUpscale,
 | 
				
			||||||
		Prompt:      data.Prompt,
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
		UserId:      userId,
 | 
							UserId:      userId,
 | 
				
			||||||
 | 
							ChannelId:   data.ChannelId,
 | 
				
			||||||
		Index:       data.Index,
 | 
							Index:       data.Index,
 | 
				
			||||||
		MessageId:   data.MessageId,
 | 
							MessageId:   data.MessageId,
 | 
				
			||||||
		MessageHash: data.MessageHash,
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
@@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
						idValue, _ := c.Get(types.LoginUserID)
 | 
				
			||||||
	jobId := 0
 | 
						jobId := 0
 | 
				
			||||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
						userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						job := model.MidJourneyJob{
 | 
				
			||||||
 | 
							Type:        types.TaskVariation.String(),
 | 
				
			||||||
 | 
							ReferenceId: data.MessageId,
 | 
				
			||||||
 | 
							UserId:      userId,
 | 
				
			||||||
 | 
							TaskId:      data.TaskId,
 | 
				
			||||||
 | 
							Progress:    0,
 | 
				
			||||||
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
 | 
							CreatedAt:   time.Now(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if res := h.db.Create(&job); res.Error != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h.pool.PushTask(types.MjTask{
 | 
						h.pool.PushTask(types.MjTask{
 | 
				
			||||||
		Id:          jobId,
 | 
							Id:          jobId,
 | 
				
			||||||
		SessionId:   data.SessionId,
 | 
							SessionId:   data.SessionId,
 | 
				
			||||||
@@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
		Prompt:      data.Prompt,
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
		UserId:      userId,
 | 
							UserId:      userId,
 | 
				
			||||||
		Index:       data.Index,
 | 
							Index:       data.Index,
 | 
				
			||||||
 | 
							ChannelId:   data.ChannelId,
 | 
				
			||||||
		MessageId:   data.MessageId,
 | 
							MessageId:   data.MessageId,
 | 
				
			||||||
		MessageHash: data.MessageHash,
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
				
			|||||||
	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
						if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
				
			||||||
		// parse content
 | 
							// parse content
 | 
				
			||||||
		req := CBReq{
 | 
							req := CBReq{
 | 
				
			||||||
 | 
								ChannelId:   m.ChannelID,
 | 
				
			||||||
			MessageId:   m.ID,
 | 
								MessageId:   m.ID,
 | 
				
			||||||
			ReferenceId: referenceId,
 | 
								ReferenceId: referenceId,
 | 
				
			||||||
			Prompt:      extractPrompt(m.Content),
 | 
								Prompt:      extractPrompt(m.Content),
 | 
				
			||||||
@@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
						b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
					func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
				
			||||||
@@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	if strings.Contains(m.Content, "(Stopped)") {
 | 
						if strings.Contains(m.Content, "(Stopped)") {
 | 
				
			||||||
		req := CBReq{
 | 
							req := CBReq{
 | 
				
			||||||
 | 
								ChannelId:   m.ChannelID,
 | 
				
			||||||
			MessageId:   m.ID,
 | 
								MessageId:   m.ID,
 | 
				
			||||||
			ReferenceId: referenceId,
 | 
								ReferenceId: referenceId,
 | 
				
			||||||
			Prompt:      extractPrompt(m.Content),
 | 
								Prompt:      extractPrompt(m.Content),
 | 
				
			||||||
@@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
						b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
					func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
				
			||||||
	progress := extractProgress(content)
 | 
						progress := extractProgress(content)
 | 
				
			||||||
	var status TaskStatus
 | 
						var status TaskStatus
 | 
				
			||||||
	if progress == 100 {
 | 
						if progress == 100 {
 | 
				
			||||||
@@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
 | 
				
			|||||||
			Hash:     extractHashFromFilename(attachment.Filename),
 | 
								Hash:     extractHashFromFilename(attachment.Filename),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		req := CBReq{
 | 
							req := CBReq{
 | 
				
			||||||
 | 
								ChannelId:   channelId,
 | 
				
			||||||
			MessageId:   messageId,
 | 
								MessageId:   messageId,
 | 
				
			||||||
			ReferenceId: referenceId,
 | 
								ReferenceId: referenceId,
 | 
				
			||||||
			Image:       image,
 | 
								Image:       image,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
				
			|||||||
	if proxy != "" {
 | 
						if proxy != "" {
 | 
				
			||||||
		client.SetProxyURL(proxy)
 | 
							client.SetProxyURL(proxy)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	logger.Info(proxy)
 | 
						logger.Info(config)
 | 
				
			||||||
	return &Client{client: client, config: config}
 | 
						return &Client{client: client, config: config}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -54,6 +54,15 @@ func (s *Service) Run() {
 | 
				
			|||||||
		err := s.taskQueue.LPop(&task)
 | 
							err := s.taskQueue.LPop(&task)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			logger.Errorf("taking task with error: %v", err)
 | 
								logger.Errorf("taking task with error: %v", err)
 | 
				
			||||||
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// if it's reference message, check if it's this channel's  message
 | 
				
			||||||
 | 
							if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
 | 
				
			||||||
 | 
								s.taskQueue.RPush(task)
 | 
				
			||||||
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -74,7 +83,6 @@ func (s *Service) Run() {
 | 
				
			|||||||
			logger.Error("绘画任务执行失败:", err)
 | 
								logger.Error("绘画任务执行失败:", err)
 | 
				
			||||||
			// update the task progress
 | 
								// update the task progress
 | 
				
			||||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
					 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	res = s.db.Where("task_id = ?", split[0]).First(&job)
 | 
						tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
 | 
				
			||||||
 | 
						if data.ReferenceId != "" {
 | 
				
			||||||
 | 
							tx = tx.Where("reference_id = ?", data.ReferenceId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						res = tx.First(&job)
 | 
				
			||||||
	if res.Error != nil {
 | 
						if res.Error != nil {
 | 
				
			||||||
		logger.Warn("非法任务:", res.Error)
 | 
							logger.Warn("非法任务:", res.Error)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						job.ChannelId = data.ChannelId
 | 
				
			||||||
	job.MessageId = data.MessageId
 | 
						job.MessageId = data.MessageId
 | 
				
			||||||
	job.ReferenceId = data.ReferenceId
 | 
						job.ReferenceId = data.ReferenceId
 | 
				
			||||||
	job.Progress = data.Progress
 | 
						job.Progress = data.Progress
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,6 +24,7 @@ type InteractionsResult struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type CBReq struct {
 | 
					type CBReq struct {
 | 
				
			||||||
 | 
						ChannelId   string     `json:"channel_id"`
 | 
				
			||||||
	MessageId   string     `json:"message_id"`
 | 
						MessageId   string     `json:"message_id"`
 | 
				
			||||||
	ReferenceId string     `json:"reference_id"`
 | 
						ReferenceId string     `json:"reference_id"`
 | 
				
			||||||
	Image       Image      `json:"image"`
 | 
						Image       Image      `json:"image"`
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,7 @@ type MidJourneyJob struct {
 | 
				
			|||||||
	Type        string
 | 
						Type        string
 | 
				
			||||||
	UserId      int
 | 
						UserId      int
 | 
				
			||||||
	TaskId      string
 | 
						TaskId      string
 | 
				
			||||||
 | 
						ChannelId   string
 | 
				
			||||||
	MessageId   string
 | 
						MessageId   string
 | 
				
			||||||
	ReferenceId string
 | 
						ReferenceId string
 | 
				
			||||||
	ImgURL      string
 | 
						ImgURL      string
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ type MidJourneyJob struct {
 | 
				
			|||||||
	Id          uint      `json:"id"`
 | 
						Id          uint      `json:"id"`
 | 
				
			||||||
	Type        string    `json:"type"`
 | 
						Type        string    `json:"type"`
 | 
				
			||||||
	UserId      int       `json:"user_id"`
 | 
						UserId      int       `json:"user_id"`
 | 
				
			||||||
 | 
						ChannelId   string    `json:"channel_id"`
 | 
				
			||||||
	TaskId      string    `json:"task_id"`
 | 
						TaskId      string    `json:"task_id"`
 | 
				
			||||||
	MessageId   string    `json:"message_id"`
 | 
						MessageId   string    `json:"message_id"`
 | 
				
			||||||
	ReferenceId string    `json:"reference_id"`
 | 
						ReferenceId string    `json:"reference_id"`
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								database/update-v3.2.3.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								database/update-v3.2.3.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`;
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`;
 | 
				
			||||||
@@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    runningJobs.value = _jobs
 | 
					    runningJobs.value = _jobs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    setTimeout(() => fetchRunningJobs(userId), 5000)
 | 
					    setTimeout(() => fetchRunningJobs(userId), 3000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
					    ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
@@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => {
 | 
				
			|||||||
  // 获取已完成的任务
 | 
					  // 获取已完成的任务
 | 
				
			||||||
  httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
 | 
					  httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
 | 
				
			||||||
    finishedJobs.value = res.data
 | 
					    finishedJobs.value = res.data
 | 
				
			||||||
    setTimeout(() => fetchFinishJobs(userId), 5000)
 | 
					    setTimeout(() => fetchFinishJobs(userId), 3000)
 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
					    ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
@@ -660,6 +660,8 @@ const variation = (index, item) => {
 | 
				
			|||||||
const send = (url, index, item) => {
 | 
					const send = (url, index, item) => {
 | 
				
			||||||
  httpPost(url, {
 | 
					  httpPost(url, {
 | 
				
			||||||
    index: index,
 | 
					    index: index,
 | 
				
			||||||
 | 
					    task_id: item.task_id,
 | 
				
			||||||
 | 
					    channel_id: item.channel_id,
 | 
				
			||||||
    message_id: item.message_id,
 | 
					    message_id: item.message_id,
 | 
				
			||||||
    message_hash: item.hash,
 | 
					    message_hash: item.hash,
 | 
				
			||||||
    session_id: getSessionId(),
 | 
					    session_id: getSessionId(),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -598,7 +598,7 @@ onMounted(() => {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
      runningJobs.value = _jobs
 | 
					      runningJobs.value = _jobs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      setTimeout(() => fetchRunningJobs(userId), 5000)
 | 
					      setTimeout(() => fetchRunningJobs(userId), 3000)
 | 
				
			||||||
    }).catch(e => {
 | 
					    }).catch(e => {
 | 
				
			||||||
      ElMessage.error("获取任务失败:" + e.message)
 | 
					      ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
@@ -608,7 +608,7 @@ onMounted(() => {
 | 
				
			|||||||
  const fetchFinishJobs = (userId) => {
 | 
					  const fetchFinishJobs = (userId) => {
 | 
				
			||||||
    httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
 | 
					    httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
 | 
				
			||||||
      finishedJobs.value = res.data
 | 
					      finishedJobs.value = res.data
 | 
				
			||||||
      setTimeout(() => fetchFinishJobs(userId), 5000)
 | 
					      setTimeout(() => fetchFinishJobs(userId), 3000)
 | 
				
			||||||
    }).catch(e => {
 | 
					    }).catch(e => {
 | 
				
			||||||
      ElMessage.error("获取任务失败:" + e.message)
 | 
					      ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user