mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
opt: make sure the Upscale and Variation task is assign to the same mj service with Image task
This commit is contained in:
parent
245cd3ee1a
commit
abf4f061c1
@ -16,6 +16,7 @@ const (
|
|||||||
// MjTask MidJourney 任务
|
// MjTask MidJourney 任务
|
||||||
type MjTask struct {
|
type MjTask struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
|
ChannelId string `json:"channel_id"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
|
@ -147,8 +147,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type reqVo struct {
|
type reqVo struct {
|
||||||
Src string `json:"src"`
|
TaskId string `json:"task_id"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
|
ChannelId string `json:"channel_id"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
MessageHash string `json:"message_hash"`
|
MessageHash string `json:"message_hash"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
@ -173,12 +174,27 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
job := model.MidJourneyJob{
|
||||||
|
Type: types.TaskUpscale.String(),
|
||||||
|
ReferenceId: data.MessageId,
|
||||||
|
UserId: userId,
|
||||||
|
TaskId: data.TaskId,
|
||||||
|
Progress: 0,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if res := h.db.Create(&job); res.Error != nil {
|
||||||
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: jobId,
|
Id: jobId,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Type: types.TaskUpscale,
|
Type: types.TaskUpscale,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
|
ChannelId: data.ChannelId,
|
||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
@ -201,6 +217,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
|
||||||
|
job := model.MidJourneyJob{
|
||||||
|
Type: types.TaskVariation.String(),
|
||||||
|
ReferenceId: data.MessageId,
|
||||||
|
UserId: userId,
|
||||||
|
TaskId: data.TaskId,
|
||||||
|
Progress: 0,
|
||||||
|
Prompt: data.Prompt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if res := h.db.Create(&job); res.Error != nil {
|
||||||
|
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: jobId,
|
Id: jobId,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
@ -208,6 +239,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
|
ChannelId: data.ChannelId,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
})
|
||||||
|
@ -101,6 +101,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|||||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
||||||
// parse content
|
// parse content
|
||||||
req := CBReq{
|
req := CBReq{
|
||||||
|
ChannelId: m.ChannelID,
|
||||||
MessageId: m.ID,
|
MessageId: m.ID,
|
||||||
ReferenceId: referenceId,
|
ReferenceId: referenceId,
|
||||||
Prompt: extractPrompt(m.Content),
|
Prompt: extractPrompt(m.Content),
|
||||||
@ -111,7 +112,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||||
@ -132,6 +133,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
|||||||
}
|
}
|
||||||
if strings.Contains(m.Content, "(Stopped)") {
|
if strings.Contains(m.Content, "(Stopped)") {
|
||||||
req := CBReq{
|
req := CBReq{
|
||||||
|
ChannelId: m.ChannelID,
|
||||||
MessageId: m.ID,
|
MessageId: m.ID,
|
||||||
ReferenceId: referenceId,
|
ReferenceId: referenceId,
|
||||||
Prompt: extractPrompt(m.Content),
|
Prompt: extractPrompt(m.Content),
|
||||||
@ -142,11 +144,11 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
||||||
progress := extractProgress(content)
|
progress := extractProgress(content)
|
||||||
var status TaskStatus
|
var status TaskStatus
|
||||||
if progress == 100 {
|
if progress == 100 {
|
||||||
@ -168,6 +170,7 @@ func (b *Bot) addAttachment(messageId string, referenceId string, content string
|
|||||||
Hash: extractHashFromFilename(attachment.Filename),
|
Hash: extractHashFromFilename(attachment.Filename),
|
||||||
}
|
}
|
||||||
req := CBReq{
|
req := CBReq{
|
||||||
|
ChannelId: channelId,
|
||||||
MessageId: messageId,
|
MessageId: messageId,
|
||||||
ReferenceId: referenceId,
|
ReferenceId: referenceId,
|
||||||
Image: image,
|
Image: image,
|
||||||
|
@ -21,7 +21,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
|||||||
if proxy != "" {
|
if proxy != "" {
|
||||||
client.SetProxyURL(proxy)
|
client.SetProxyURL(proxy)
|
||||||
}
|
}
|
||||||
logger.Info(proxy)
|
logger.Info(config)
|
||||||
return &Client{client: client, config: config}
|
return &Client{client: client, config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +54,15 @@ func (s *Service) Run() {
|
|||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("taking task with error: %v", err)
|
||||||
|
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if it's reference message, check if it's this channel's message
|
||||||
|
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||||
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +83,6 @@ func (s *Service) Run() {
|
|||||||
logger.Error("绘画任务执行失败:", err)
|
logger.Error("绘画任务执行失败:", err)
|
||||||
// update the task progress
|
// update the task progress
|
||||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,11 +121,17 @@ func (s *Service) Notify(data CBReq) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res = s.db.Where("task_id = ?", split[0]).First(&job)
|
tx := s.db.Where("task_id = ? AND progress < 100", split[0]).Session(&gorm.Session{}).Order("id ASC")
|
||||||
|
if data.ReferenceId != "" {
|
||||||
|
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
||||||
|
}
|
||||||
|
res = tx.First(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Warn("非法任务:", res.Error)
|
logger.Warn("非法任务:", res.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
job.ChannelId = data.ChannelId
|
||||||
job.MessageId = data.MessageId
|
job.MessageId = data.MessageId
|
||||||
job.ReferenceId = data.ReferenceId
|
job.ReferenceId = data.ReferenceId
|
||||||
job.Progress = data.Progress
|
job.Progress = data.Progress
|
||||||
|
@ -24,6 +24,7 @@ type InteractionsResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CBReq struct {
|
type CBReq struct {
|
||||||
|
ChannelId string `json:"channel_id"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
ReferenceId string `json:"reference_id"`
|
ReferenceId string `json:"reference_id"`
|
||||||
Image Image `json:"image"`
|
Image Image `json:"image"`
|
||||||
|
@ -7,6 +7,7 @@ type MidJourneyJob struct {
|
|||||||
Type string
|
Type string
|
||||||
UserId int
|
UserId int
|
||||||
TaskId string
|
TaskId string
|
||||||
|
ChannelId string
|
||||||
MessageId string
|
MessageId string
|
||||||
ReferenceId string
|
ReferenceId string
|
||||||
ImgURL string
|
ImgURL string
|
||||||
|
@ -6,6 +6,7 @@ type MidJourneyJob struct {
|
|||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
|
ChannelId string `json:"channel_id"`
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
ReferenceId string `json:"reference_id"`
|
ReferenceId string `json:"reference_id"`
|
||||||
|
2
database/update-v3.2.3.sql
Normal file
2
database/update-v3.2.3.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE `chatgpt_mj_jobs` ADD `channel_id` CHAR(40) NULL DEFAULT NULL COMMENT '频道ID' AFTER `message_id`;
|
||||||
|
ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `task_id`;
|
@ -581,7 +581,7 @@ const fetchRunningJobs = (userId) => {
|
|||||||
}
|
}
|
||||||
runningJobs.value = _jobs
|
runningJobs.value = _jobs
|
||||||
|
|
||||||
setTimeout(() => fetchRunningJobs(userId), 5000)
|
setTimeout(() => fetchRunningJobs(userId), 3000)
|
||||||
|
|
||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
@ -592,7 +592,7 @@ const fetchFinishJobs = (userId) => {
|
|||||||
// 获取已完成的任务
|
// 获取已完成的任务
|
||||||
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
|
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
|
||||||
finishedJobs.value = res.data
|
finishedJobs.value = res.data
|
||||||
setTimeout(() => fetchFinishJobs(userId), 5000)
|
setTimeout(() => fetchFinishJobs(userId), 3000)
|
||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
})
|
})
|
||||||
@ -660,6 +660,8 @@ const variation = (index, item) => {
|
|||||||
const send = (url, index, item) => {
|
const send = (url, index, item) => {
|
||||||
httpPost(url, {
|
httpPost(url, {
|
||||||
index: index,
|
index: index,
|
||||||
|
task_id: item.task_id,
|
||||||
|
channel_id: item.channel_id,
|
||||||
message_id: item.message_id,
|
message_id: item.message_id,
|
||||||
message_hash: item.hash,
|
message_hash: item.hash,
|
||||||
session_id: getSessionId(),
|
session_id: getSessionId(),
|
||||||
|
@ -598,7 +598,7 @@ onMounted(() => {
|
|||||||
}
|
}
|
||||||
runningJobs.value = _jobs
|
runningJobs.value = _jobs
|
||||||
|
|
||||||
setTimeout(() => fetchRunningJobs(userId), 5000)
|
setTimeout(() => fetchRunningJobs(userId), 3000)
|
||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
})
|
})
|
||||||
@ -608,7 +608,7 @@ onMounted(() => {
|
|||||||
const fetchFinishJobs = (userId) => {
|
const fetchFinishJobs = (userId) => {
|
||||||
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
|
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
|
||||||
finishedJobs.value = res.data
|
finishedJobs.value = res.data
|
||||||
setTimeout(() => fetchFinishJobs(userId), 5000)
|
setTimeout(() => fetchFinishJobs(userId), 3000)
|
||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user