mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	opt: make sure the Upscale and Variation task is assign to the same mj service with Image task
This commit is contained in:
		@@ -16,6 +16,7 @@ const (
 | 
			
		||||
// MjTask MidJourney 任务
 | 
			
		||||
type MjTask struct {
 | 
			
		||||
	Id          int      `json:"id"`
 | 
			
		||||
	ChannelId   string   `json:"channel_id"`
 | 
			
		||||
	SessionId   string   `json:"session_id"`
 | 
			
		||||
	Type        TaskType `json:"type"`
 | 
			
		||||
	UserId      int      `json:"user_id"`
 | 
			
		||||
 
 | 
			
		||||
@@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type reqVo struct {
 | 
			
		||||
	Src         string `json:"src"`
 | 
			
		||||
	TaskId      string `json:"task_id"`
 | 
			
		||||
	Index       int    `json:"index"`
 | 
			
		||||
	ChannelId   string `json:"channel_id"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
	SessionId   string `json:"session_id"`
 | 
			
		||||
@@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	jobId := 0
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:        types.TaskUpscale.String(),
 | 
			
		||||
		ReferenceId: data.MessageId,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      data.TaskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	if res := h.db.Create(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
		Id:          jobId,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
		Type:        types.TaskUpscale,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
@@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	jobId := 0
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:        types.TaskVariation.String(),
 | 
			
		||||
		ReferenceId: data.MessageId,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      data.TaskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	if res := h.db.Create(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
		Id:          jobId,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
@@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
 
 | 
			
		||||
@@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
			
		||||
		// parse content
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   m.ChannelID,
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
@@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
@@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Stopped)") {
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   m.ChannelID,
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
@@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
			
		||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
			
		||||
	progress := extractProgress(content)
 | 
			
		||||
	var status TaskStatus
 | 
			
		||||
	if progress == 100 {
 | 
			
		||||
@@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
 | 
			
		||||
			Hash:     extractHashFromFilename(attachment.Filename),
 | 
			
		||||
		}
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   channelId,
 | 
			
		||||
			MessageId:   messageId,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Image:       image,
 | 
			
		||||
 
 | 
			
		||||
@@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
			
		||||
	if proxy != "" {
 | 
			
		||||
		client.SetProxyURL(proxy)
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info(proxy)
 | 
			
		||||
	logger.Info(config)
 | 
			
		||||
	return &Client{client: client, config: config}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -54,6 +54,15 @@ func (s *Service) Run() {
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// if it's reference message, check if it's this channel's  message
 | 
			
		||||
		if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
 | 
			
		||||
			s.taskQueue.RPush(task)
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -74,7 +83,6 @@ func (s *Service) Run() {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = s.db.Where("task_id = ?", split[0]).First(&job)
 | 
			
		||||
	tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
 | 
			
		||||
	if data.ReferenceId != "" {
 | 
			
		||||
		tx = tx.Where("reference_id = ?", data.ReferenceId)
 | 
			
		||||
	}
 | 
			
		||||
	res = tx.First(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Warn("非法任务:", res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	job.ChannelId = data.ChannelId
 | 
			
		||||
	job.MessageId = data.MessageId
 | 
			
		||||
	job.ReferenceId = data.ReferenceId
 | 
			
		||||
	job.Progress = data.Progress
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,7 @@ type InteractionsResult struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	ChannelId   string     `json:"channel_id"`
 | 
			
		||||
	MessageId   string     `json:"message_id"`
 | 
			
		||||
	ReferenceId string     `json:"reference_id"`
 | 
			
		||||
	Image       Image      `json:"image"`
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ type MidJourneyJob struct {
 | 
			
		||||
	Type        string
 | 
			
		||||
	UserId      int
 | 
			
		||||
	TaskId      string
 | 
			
		||||
	ChannelId   string
 | 
			
		||||
	MessageId   string
 | 
			
		||||
	ReferenceId string
 | 
			
		||||
	ImgURL      string
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ type MidJourneyJob struct {
 | 
			
		||||
	Id          uint      `json:"id"`
 | 
			
		||||
	Type        string    `json:"type"`
 | 
			
		||||
	UserId      int       `json:"user_id"`
 | 
			
		||||
	ChannelId   string    `json:"channel_id"`
 | 
			
		||||
	TaskId      string    `json:"task_id"`
 | 
			
		||||
	MessageId   string    `json:"message_id"`
 | 
			
		||||
	ReferenceId string    `json:"reference_id"`
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								database/update-v3.2.3.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								database/update-v3.2.3.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,2 @@
 | 
			
		||||
ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`;
 | 
			
		||||
ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`;
 | 
			
		||||
@@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => {
 | 
			
		||||
    }
 | 
			
		||||
    runningJobs.value = _jobs
 | 
			
		||||
 | 
			
		||||
    setTimeout(() => fetchRunningJobs(userId), 5000)
 | 
			
		||||
    setTimeout(() => fetchRunningJobs(userId), 3000)
 | 
			
		||||
 | 
			
		||||
  }).catch(e => {
 | 
			
		||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
			
		||||
@@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => {
 | 
			
		||||
  // 获取已完成的任务
 | 
			
		||||
  httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
 | 
			
		||||
    finishedJobs.value = res.data
 | 
			
		||||
    setTimeout(() => fetchFinishJobs(userId), 5000)
 | 
			
		||||
    setTimeout(() => fetchFinishJobs(userId), 3000)
 | 
			
		||||
  }).catch(e => {
 | 
			
		||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
			
		||||
  })
 | 
			
		||||
@@ -660,6 +660,8 @@ const variation = (index, item) => {
 | 
			
		||||
const send = (url, index, item) => {
 | 
			
		||||
  httpPost(url, {
 | 
			
		||||
    index: index,
 | 
			
		||||
    task_id: item.task_id,
 | 
			
		||||
    channel_id: item.channel_id,
 | 
			
		||||
    message_id: item.message_id,
 | 
			
		||||
    message_hash: item.hash,
 | 
			
		||||
    session_id: getSessionId(),
 | 
			
		||||
 
 | 
			
		||||
@@ -598,7 +598,7 @@ onMounted(() => {
 | 
			
		||||
      }
 | 
			
		||||
      runningJobs.value = _jobs
 | 
			
		||||
 | 
			
		||||
      setTimeout(() => fetchRunningJobs(userId), 5000)
 | 
			
		||||
      setTimeout(() => fetchRunningJobs(userId), 3000)
 | 
			
		||||
    }).catch(e => {
 | 
			
		||||
      ElMessage.error("获取任务失败:" + e.message)
 | 
			
		||||
    })
 | 
			
		||||
@@ -608,7 +608,7 @@ onMounted(() => {
 | 
			
		||||
  const fetchFinishJobs = (userId) => {
 | 
			
		||||
    httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
 | 
			
		||||
      finishedJobs.value = res.data
 | 
			
		||||
      setTimeout(() => fetchFinishJobs(userId), 5000)
 | 
			
		||||
      setTimeout(() => fetchFinishJobs(userId), 3000)
 | 
			
		||||
    }).catch(e => {
 | 
			
		||||
      ElMessage.error("获取任务失败:" + e.message)
 | 
			
		||||
    })
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user