mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
opt: make sure the Upscale and Variation task is assign to the same mj service with Image task
This commit is contained in:
parent
245cd3ee1a
commit
abf4f061c1
@ -16,6 +16,7 @@ const (
|
||||
// MjTask MidJourney 任务
|
||||
type MjTask struct {
|
||||
Id int `json:"id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
|
@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
type reqVo struct {
|
||||
Src string `json:"src"`
|
||||
TaskId string `json:"task_id"`
|
||||
Index int `json:"index"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
SessionId string `json:"session_id"`
|
||||
@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskUpscale.String(),
|
||||
ReferenceId: data.MessageId,
|
||||
UserId: userId,
|
||||
TaskId: data.TaskId,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.db.Create(&job); res.Error != nil {
|
||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Type: types.TaskUpscale,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
ReferenceId: data.MessageId,
|
||||
UserId: userId,
|
||||
TaskId: data.TaskId,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.db.Create(&job); res.Error != nil {
|
||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
})
|
||||
|
@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
||||
// parse content
|
||||
req := CBReq{
|
||||
ChannelId: m.ChannelID,
|
||||
MessageId: m.ID,
|
||||
ReferenceId: referenceId,
|
||||
Prompt: extractPrompt(m.Content),
|
||||
@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
return
|
||||
}
|
||||
|
||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
||||
}
|
||||
|
||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
}
|
||||
if strings.Contains(m.Content, "(Stopped)") {
|
||||
req := CBReq{
|
||||
ChannelId: m.ChannelID,
|
||||
MessageId: m.ID,
|
||||
ReferenceId: referenceId,
|
||||
Prompt: extractPrompt(m.Content),
|
||||
@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
||||
|
||||
}
|
||||
|
||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
||||
progress := extractProgress(content)
|
||||
var status TaskStatus
|
||||
if progress == 100 {
|
||||
@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
|
||||
Hash: extractHashFromFilename(attachment.Filename),
|
||||
}
|
||||
req := CBReq{
|
||||
ChannelId: channelId,
|
||||
MessageId: messageId,
|
||||
ReferenceId: referenceId,
|
||||
Image: image,
|
||||
|
@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
||||
if proxy != "" {
|
||||
client.SetProxyURL(proxy)
|
||||
}
|
||||
logger.Info(proxy)
|
||||
logger.Info(config)
|
||||
return &Client{client: client, config: config}
|
||||
}
|
||||
|
||||
|
@ -54,6 +54,15 @@ func (s *Service) Run() {
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
continue
|
||||
}
|
||||
|
||||
// if it's reference message, check if it's this channel's message
|
||||
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
|
||||
s.taskQueue.RPush(task)
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -74,7 +83,6 @@ func (s *Service) Run() {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
|
||||
return
|
||||
}
|
||||
|
||||
res = s.db.Where("task_id = ?", split[0]).First(&job)
|
||||
tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
|
||||
if data.ReferenceId != "" {
|
||||
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
||||
}
|
||||
res = tx.First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
job.ChannelId = data.ChannelId
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
|
@ -24,6 +24,7 @@ type InteractionsResult struct {
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
Image Image `json:"image"`
|
||||
|
@ -7,6 +7,7 @@ type MidJourneyJob struct {
|
||||
Type string
|
||||
UserId int
|
||||
TaskId string
|
||||
ChannelId string
|
||||
MessageId string
|
||||
ReferenceId string
|
||||
ImgURL string
|
||||
|
@ -6,6 +6,7 @@ type MidJourneyJob struct {
|
||||
Id uint `json:"id"`
|
||||
Type string `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
|
2
database/update-v3.2.3.sql
Normal file
2
database/update-v3.2.3.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`;
|
||||
ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`;
|
@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => {
|
||||
}
|
||||
runningJobs.value = _jobs
|
||||
|
||||
setTimeout(() => fetchRunningJobs(userId), 5000)
|
||||
setTimeout(() => fetchRunningJobs(userId), 3000)
|
||||
|
||||
}).catch(e => {
|
||||
ElMessage.error("获取任务失败:" + e.message)
|
||||
@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => {
|
||||
// 获取已完成的任务
|
||||
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
|
||||
finishedJobs.value = res.data
|
||||
setTimeout(() => fetchFinishJobs(userId), 5000)
|
||||
setTimeout(() => fetchFinishJobs(userId), 3000)
|
||||
}).catch(e => {
|
||||
ElMessage.error("获取任务失败:" + e.message)
|
||||
})
|
||||
@ -660,6 +660,8 @@ const variation = (index, item) => {
|
||||
const send = (url, index, item) => {
|
||||
httpPost(url, {
|
||||
index: index,
|
||||
task_id: item.task_id,
|
||||
channel_id: item.channel_id,
|
||||
message_id: item.message_id,
|
||||
message_hash: item.hash,
|
||||
session_id: getSessionId(),
|
||||
|
@ -598,7 +598,7 @@ onMounted(() => {
|
||||
}
|
||||
runningJobs.value = _jobs
|
||||
|
||||
setTimeout(() => fetchRunningJobs(userId), 5000)
|
||||
setTimeout(() => fetchRunningJobs(userId), 3000)
|
||||
}).catch(e => {
|
||||
ElMessage.error("获取任务失败:" + e.message)
|
||||
})
|
||||
@ -608,7 +608,7 @@ onMounted(() => {
|
||||
const fetchFinishJobs = (userId) => {
|
||||
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
|
||||
finishedJobs.value = res.data
|
||||
setTimeout(() => fetchFinishJobs(userId), 5000)
|
||||
setTimeout(() => fetchFinishJobs(userId), 3000)
|
||||
}).catch(e => {
|
||||
ElMessage.error("获取任务失败:" + e.message)
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user