use http pull message to page notify

This commit is contained in:
RockYang
2025-03-04 06:54:30 +08:00
23 changed files with 482 additions and 670 deletions

View File

@@ -3,6 +3,7 @@
## v4.2.1 ## v4.2.1
- 功能新增:新增支持可灵生成视频,支持文生视频,图生生视频。 - 功能新增:新增支持可灵生成视频,支持文生视频,图生生视频。
- 功能新增:重构所有异步任务(绘图,音乐,视频)更新方式,使用 http pull 来替代 websocket。
- 功能优化:优化 Luma 图生视频功能,支持本地上传图片和远程图片。 - 功能优化:优化 Luma 图生视频功能,支持本地上传图片和远程图片。
- Bug 修复:修复移动端聊天页面新建对话时候角色没有更模型绑定的 Bug。 - Bug 修复:修复移动端聊天页面新建对话时候角色没有更模型绑定的 Bug。
- 功能优化:优化聊天页面代码块样式,优化公式的解析。 - 功能优化:优化聊天页面代码块样式,优化公式的解析。

View File

@@ -26,7 +26,6 @@ const (
type MjTask struct { type MjTask struct {
Id uint `json:"id"` // 任务ID Id uint `json:"id"` // 任务ID
TaskId string `json:"task_id"` // 中转任务ID TaskId string `json:"task_id"` // 中转任务ID
ClientId string `json:"client_id"`
ImgArr []string `json:"img_arr"` ImgArr []string `json:"img_arr"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
@@ -44,7 +43,6 @@ type MjTask struct {
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
Type TaskType `json:"type"` Type TaskType `json:"type"`
ClientId string `json:"client_id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Params SdTaskParams `json:"params"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
@@ -52,7 +50,6 @@ type SdTask struct {
} }
type SdTaskParams struct { type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词
@@ -73,7 +70,6 @@ type SdTaskParams struct {
// DallTask DALL-E task // DallTask DALL-E task
type DallTask struct { type DallTask struct {
ClientId string `json:"client_id"`
ModelId uint `json:"model_id"` ModelId uint `json:"model_id"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
Id uint `json:"id"` Id uint `json:"id"`
@@ -88,7 +84,6 @@ type DallTask struct {
} }
type SunoTask struct { type SunoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"` Id uint `json:"id"`
Channel string `json:"channel"` Channel string `json:"channel"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
@@ -113,7 +108,6 @@ const (
) )
type VideoTask struct { type VideoTask struct {
ClientId string `json:"client_id"`
Id uint `json:"id"` Id uint `json:"id"`
Channel string `json:"channel"` Channel string `json:"channel"`
UserId int `json:"user_id"` UserId int `json:"user_id"`

View File

@@ -70,7 +70,6 @@ func (h *DallJobHandler) Image(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
task := types.DallTask{ task := types.DallTask{
ClientId: data.ClientId,
UserId: uint(userId), UserId: uint(userId),
ModelId: chatModel.Id, ModelId: chatModel.Id,
ModelName: chatModel.Value, ModelName: chatModel.Value,

View File

@@ -66,7 +66,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct { var data struct {
TaskType string `json:"task_type"` TaskType string `json:"task_type"`
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"` NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"` Rate string `json:"rate"`
@@ -153,7 +152,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
task := types.MjTask{ task := types.MjTask{
ClientId: data.ClientId,
TaskId: taskId, TaskId: taskId,
Type: types.TaskType(data.TaskType), Type: types.TaskType(data.TaskType),
Prompt: data.Prompt, Prompt: data.Prompt,
@@ -207,7 +205,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
type reqVo struct { type reqVo struct {
Index int `json:"index"` Index int `json:"index"`
ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"` ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"` MessageHash string `json:"message_hash"`
@@ -229,7 +226,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true) taskId, _ := h.snowflake.Next(true)
task := types.MjTask{ task := types.MjTask{
ClientId: data.ClientId,
Type: types.TaskUpscale, Type: types.TaskUpscale,
UserId: userId, UserId: userId,
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
@@ -286,7 +282,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
taskId, _ := h.snowflake.Next(true) taskId, _ := h.snowflake.Next(true)
task := types.MjTask{ task := types.MjTask{
Type: types.TaskVariation, Type: types.TaskVariation,
ClientId: data.ClientId,
UserId: userId, UserId: userId,
Index: data.Index, Index: data.Index,
ChannelId: data.ChannelId, ChannelId: data.ChannelId,

View File

@@ -112,8 +112,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
} }
task := types.SdTask{ task := types.SdTask{
ClientId: data.ClientId, Type: types.TaskImage,
Type: types.TaskImage,
Params: types.SdTaskParams{ Params: types.SdTaskParams{
TaskId: taskId, TaskId: taskId,
Prompt: data.Prompt, Prompt: data.Prompt,

View File

@@ -45,7 +45,6 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
func (h *SunoHandler) Create(c *gin.Context) { func (h *SunoHandler) Create(c *gin.Context) {
var data struct { var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"` Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"` Lyrics string `json:"lyrics"`
@@ -90,7 +89,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
} }
} }
task := types.SunoTask{ task := types.SunoTask{
ClientId: data.ClientId,
UserId: int(h.GetLoginUserId(c)), UserId: int(h.GetLoginUserId(c)),
Type: data.Type, Type: data.Type,
Title: data.Title, Title: data.Title,

View File

@@ -46,7 +46,6 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
func (h *VideoHandler) LumaCreate(c *gin.Context) { func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct { var data struct {
ClientId string `json:"client_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"` FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"` EndFrameImg string `json:"end_frame_img,omitempty"`
@@ -82,7 +81,6 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
EndImgURL: data.EndFrameImg, EndImgURL: data.EndFrameImg,
} }
task := types.VideoTask{ task := types.VideoTask{
ClientId: data.ClientId,
UserId: userId, UserId: userId,
Type: types.VideoLuma, Type: types.VideoLuma,
Prompt: data.Prompt, Prompt: data.Prompt,
@@ -124,7 +122,6 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
var data struct { var data struct {
Channel string `json:"channel"` Channel string `json:"channel"`
ClientId string `json:"client_id"`
TaskType string `json:"task_type"` // 任务类型: text2video/image2video TaskType string `json:"task_type"` // 任务类型: text2video/image2video
Model string `json:"model"` // 模型: default/anime Model string `json:"model"` // 模型: default/anime
Prompt string `json:"prompt"` // 视频描述 Prompt string `json:"prompt"` // 视频描述
@@ -173,7 +170,6 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
ImageTail: data.ImageTail, ImageTail: data.ImageTail,
} }
task := types.VideoTask{ task := types.VideoTask{
ClientId: data.ClientId,
UserId: userId, UserId: userId,
Type: types.VideoKeLing, Type: types.VideoKeLing,
Prompt: data.Prompt, Prompt: data.Prompt,
@@ -218,14 +214,14 @@ func (h *VideoHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)
all := h.GetBool(c, "all") all := h.GetBool(c, "all")
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) session := h.DB.Session(&gorm.Session{})
if t != "" { if t != "" {
session = session.Where("type", t) session = session.Where("type", t)
} }
if all { if all {
session = session.Where("publish", 0).Where("progress", 100) session = session.Where("publish", 0).Where("progress", 100)
} else { } else {
session = session.Where("user_id", h.GetLoginUserId(c)) session = session.Where("user_id", userId)
} }
// 统计总数 // 统计总数
var total int64 var total int64

View File

@@ -163,7 +163,6 @@ func main() {
fx.Provide(dalle.NewService), fx.Provide(dalle.NewService),
fx.Invoke(func(s *dalle.Service) { fx.Invoke(func(s *dalle.Service) {
s.Run() s.Run()
s.CheckTaskNotify()
s.DownloadImages() s.DownloadImages()
s.CheckTaskStatus() s.CheckTaskStatus()
}), }),
@@ -182,7 +181,6 @@ func main() {
fx.Invoke(func(s *mj.Service) { fx.Invoke(func(s *mj.Service) {
s.Run() s.Run()
s.SyncTaskProgress() s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadImages() s.DownloadImages()
}), }),
@@ -191,21 +189,18 @@ func main() {
fx.Invoke(func(s *sd.Service, config *types.AppConfig) { fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
s.Run() s.Run()
s.CheckTaskStatus() s.CheckTaskStatus()
s.CheckTaskNotify()
}), }),
fx.Provide(suno.NewService), fx.Provide(suno.NewService),
fx.Invoke(func(s *suno.Service) { fx.Invoke(func(s *suno.Service) {
s.Run() s.Run()
s.SyncTaskProgress() s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles() s.DownloadFiles()
}), }),
fx.Provide(video.NewService), fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) { fx.Invoke(func(s *video.Service) {
s.Run() s.Run()
s.SyncTaskProgress() s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles() s.DownloadFiles()
}), }),
fx.Provide(service.NewUserService), fx.Provide(service.NewUserService),

View File

@@ -34,10 +34,8 @@ type Service struct {
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
userService *service.UserService userService *service.UserService
wsService *service.WebsocketService wsService *service.WebsocketService
clientIds map[uint]string
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
@@ -45,11 +43,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
wsService: wsService, wsService: wsService,
uploadManager: manager, uploadManager: manager,
userService: userService, userService: userService,
clientIds: map[uint]string{},
} }
} }
@@ -60,7 +56,7 @@ func (s *Service) PushTask(task types.DallTask) {
} }
func (s *Service) Run() { func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列 // 将数据库中未提交的任务加载到队列
var jobs []model.DallJob var jobs []model.DallJob
s.db.Where("progress", 0).Find(&jobs) s.db.Where("progress", 0).Find(&jobs)
for _, v := range jobs { for _, v := range jobs {
@@ -84,16 +80,16 @@ func (s *Service) Run() {
continue continue
} }
logger.Infof("handle a new DALL-E task: %+v", task) logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.Id] = task.ClientId go func() {
_, err = s.Image(task, false) _, err = s.Image(task, false)
if err != nil { if err != nil {
logger.Errorf("error with image task: %v", err) logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed}) }
} }()
} }
}() }()
} }
@@ -212,10 +208,9 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", fmt.Errorf("err with update database: %v", err) return "", fmt.Errorf("err with update database: %v", err)
} }
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url) imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
if err != nil { if err != nil {
return "", fmt.Errorf("error with download image: %v", err) return "", fmt.Errorf("error with download image: %v", err)
} }
@@ -225,26 +220,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return content, nil return content, nil
} }
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChDall, message.Message)
}
}()
}
func (s *Service) CheckTaskStatus() { func (s *Service) CheckTaskStatus() {
go func() { go func() {
logger.Info("Running DALL-E task status checking ...") logger.Info("Running DALL-E task status checking ...")
@@ -254,7 +229,7 @@ func (s *Service) CheckTaskStatus() {
s.db.Where("progress < ?", 100).Find(&jobs) s.db.Where("progress < ?", 100).Find(&jobs)
for _, job := range jobs { for _, job := range jobs {
// 超时的任务标记为失败 // 超时的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 { if time.Since(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时" job.ErrMsg = "任务超时"
s.db.Updates(&job) s.db.Updates(&job)
@@ -301,7 +276,7 @@ func (s *Service) DownloadImages() {
} }
logger.Infof("try to download image: %s", v.OrgURL) logger.Infof("try to download image: %s", v.OrgURL)
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL) imgURL, err := s.downloadImage(v.Id, v.OrgURL)
if err != nil { if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err) logger.Error("error with download image: %s, error: %v", imgURL, err)
continue continue
@@ -316,7 +291,7 @@ func (s *Service) DownloadImages() {
}() }()
} }
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
// sava image // sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false) imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
if err != nil { if err != nil {
@@ -328,6 +303,5 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil { if res.Error != nil {
return "", err return "", err
} }
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil return imgURL, nil
} }

View File

@@ -15,10 +15,11 @@ import (
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"strings" "strings"
"time" "time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -26,23 +27,19 @@ import (
type Service struct { type Service struct {
client *Client // MJ Client client *Client // MJ Client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
wsService *service.WebsocketService wsService *service.WebsocketService
uploaderManager *oss.UploaderManager uploaderManager *oss.UploaderManager
userService *service.UserService userService *service.UserService
clientIds map[uint]string
} }
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service { func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service {
return &Service{ return &Service{
db: db, db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
client: client, client: client,
wsService: wsService, wsService: wsService,
uploaderManager: manager, uploaderManager: manager,
clientIds: map[uint]string{},
userService: userService, userService: userService,
} }
} }
@@ -59,7 +56,6 @@ func (s *Service) Run() {
continue continue
} }
task.Id = v.Id task.Id = v.Id
s.clientIds[task.Id] = task.ClientId
s.PushTask(task) s.PushTask(task)
} }
@@ -96,7 +92,6 @@ func (s *Service) Run() {
if task.Mode == "" { if task.Mode == "" {
task.Mode = "fast" task.Mode = "fast"
} }
s.clientIds[task.Id] = task.ClientId
var job model.MidJourneyJob var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job) tx := s.db.Where("id = ?", task.Id).First(&job)
@@ -139,7 +134,6 @@ func (s *Service) Run() {
// update the task progress // update the task progress
s.db.Updates(&job) s.db.Updates(&job)
// 任务失败,通知前端 // 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue continue
} }
logger.Infof("任务提交成功:%+v", res) logger.Infof("任务提交成功:%+v", res)
@@ -178,24 +172,6 @@ func GetImageHash(action string) string {
return split[len(split)-1] return split[len(split)-1]
} }
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("receive a new mj notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChMj, message.Message)
}
}()
}
func (s *Service) DownloadImages() { func (s *Service) DownloadImages() {
go func() { go func() {
var items []model.MidJourneyJob var items []model.MidJourneyJob
@@ -228,12 +204,6 @@ func (s *Service) DownloadImages() {
v.ImgURL = imgURL v.ImgURL = imgURL
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[v.Id],
UserId: v.UserId,
JobId: int(v.Id),
Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
@@ -259,7 +229,7 @@ func (s *Service) SyncTaskProgress() {
for _, job := range jobs { for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败 // 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 { if time.Since(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时" job.ErrMsg = "任务超时"
s.db.Updates(&job) s.db.Updates(&job)
@@ -279,18 +249,12 @@ func (s *Service) SyncTaskProgress() {
"err_msg": task.FailReason, "err_msg": task.FailReason,
}) })
logger.Errorf("task failed: %v", task.FailReason) logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed})
continue continue
} }
if len(task.Buttons) > 0 { if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId) job.Hash = GetImageHash(task.Buttons[0].CustomId)
} }
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
if task.ImageUrl != "" { if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl job.OrgURL = task.ImageUrl
@@ -300,19 +264,6 @@ func (s *Service) SyncTaskProgress() {
logger.Errorf("error with update database: %v", err) logger.Errorf("error with update database: %v", err)
continue continue
} }
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := service.TaskStatusRunning
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: message})
}
} }
// 找出失败的任务,并恢复其扣减算力 // 找出失败的任务,并恢复其扣减算力

View File

@@ -16,9 +16,10 @@ import (
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -30,7 +31,6 @@ var logger = logger2.GetLogger()
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
wsService *service.WebsocketService wsService *service.WebsocketService
@@ -41,7 +41,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
return &Service{ return &Service{
httpClient: req.C(), httpClient: req.C(),
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
wsService: wsService, wsService: wsService,
uploadManager: manager, uploadManager: manager,
@@ -102,8 +101,6 @@ func (s *Service) Run() {
"progress": service.FailTaskProgress, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue continue
} }
} }
@@ -225,15 +222,12 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
return nil return nil
default: default:
err, resp := s.checkTaskProgress(apiKey) resp, err := s.checkTaskProgress(apiKey)
// 更新任务进度 // 更新任务进度
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@@ -242,7 +236,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
// 执行任务 // 执行任务
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) { func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, error) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
var res TaskProgressResp var res TaskProgressResp
response, err := s.httpClient.R(). response, err := s.httpClient.R().
@@ -250,13 +244,13 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressRe
SetSuccessResult(&res). SetSuccessResult(&res).
Get(apiURL) Get(apiURL)
if err != nil { if err != nil {
return err, nil return nil, err
} }
if response.IsErrorState() { if response.IsErrorState() {
return fmt.Errorf("error http code status: %v", response.Status), nil return nil, fmt.Errorf("error http code status: %v", response.Status)
} }
return nil, &res return &res, nil
} }
func (s *Service) PushTask(task types.SdTask) { func (s *Service) PushTask(task types.SdTask) {
@@ -264,25 +258,6 @@ func (s *Service) PushTask(task types.SdTask) {
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
} }
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSd, message.Message)
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 // CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() { func (s *Service) CheckTaskStatus() {
go func() { go func() {
@@ -297,7 +272,7 @@ func (s *Service) CheckTaskStatus() {
for _, job := range jobs { for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败 // 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 { if time.Since(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时" job.ErrMsg = "任务超时"
s.db.Updates(&job) s.db.Updates(&job)

View File

@@ -18,10 +18,11 @@ import (
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"io" "io"
"time" "time"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -35,7 +36,6 @@ type Service struct {
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
wsService *service.WebsocketService wsService *service.WebsocketService
clientIds map[string]string
userService *service.UserService userService *service.UserService
} }
@@ -47,7 +47,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli), notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
uploadManager: manager, uploadManager: manager,
wsService: wsService, wsService: wsService,
clientIds: map[string]string{},
userService: userService, userService: userService,
} }
} }
@@ -70,7 +69,6 @@ func (s *Service) Run() {
} }
task.Id = v.Id task.Id = v.Id
s.PushTask(task) s.PushTask(task)
s.clientIds[v.TaskId] = task.ClientId
} }
logger.Info("Starting Suno job consumer...") logger.Info("Starting Suno job consumer...")
go func() { go func() {
@@ -95,7 +93,6 @@ func (s *Service) Run() {
"err_msg": err.Error(), "err_msg": err.Error(),
"progress": service.FailTaskProgress, "progress": service.FailTaskProgress,
}) })
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue continue
} }
@@ -104,7 +101,6 @@ func (s *Service) Run() {
"task_id": r.Data, "task_id": r.Data,
"channel": r.Channel, "channel": r.Channel,
}) })
s.clientIds[r.Data] = task.ClientId
} }
}() }()
} }
@@ -262,27 +258,6 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
return res, nil return res, nil
} }
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
logger.Debugf("client id: %+v", s.wsService.Clients)
client := s.wsService.Clients.Get(message.ClientId)
logger.Debugf("%+v", client)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSuno, message.Message)
}
}()
}
func (s *Service) DownloadFiles() { func (s *Service) DownloadFiles() {
go func() { go func() {
var items []model.SunoJob var items []model.SunoJob
@@ -311,7 +286,6 @@ func (s *Service) DownloadFiles() {
v.AudioURL = audioURL v.AudioURL = audioURL
v.Progress = 100 v.Progress = 100
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.TaskId], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
@@ -377,12 +351,10 @@ func (s *Service) SyncTaskProgress() {
} }
} }
tx.Commit() tx.Commit()
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFinished})
} else if task.Data.FailReason != "" { } else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason job.ErrMsg = task.Data.FailReason
s.db.Updates(&job) s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[job.TaskId], UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
} }
} }

View File

@@ -20,7 +20,6 @@ import (
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"time" "time"
@@ -37,9 +36,7 @@ type Service struct {
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
wsService *service.WebsocketService wsService *service.WebsocketService
clientIds map[uint]string
userService *service.UserService userService *service.UserService
} }
@@ -48,10 +45,8 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
httpClient: req.C().SetTimeout(time.Minute * 3), httpClient: req.C().SetTimeout(time.Minute * 3),
db: db, db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
wsService: wsService, wsService: wsService,
uploadManager: manager, uploadManager: manager,
clientIds: map[uint]string{},
userService: userService, userService: userService,
} }
} }
@@ -74,7 +69,6 @@ func (s *Service) Run() {
} }
task.Id = v.Id task.Id = v.Id
s.PushTask(task) s.PushTask(task)
s.clientIds[v.Id] = task.ClientId
} }
logger.Info("Starting Video job consumer...") logger.Info("Starting Video job consumer...")
go func() { go func() {
@@ -86,10 +80,6 @@ func (s *Service) Run() {
continue continue
} }
if task.ClientId != "" {
s.clientIds[task.Id] = task.ClientId
}
if task.Type == types.VideoLuma { if task.Type == types.VideoLuma {
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
@@ -112,7 +102,6 @@ func (s *Service) Run() {
if err != nil { if err != nil {
logger.Errorf("update task with error: %v", err) logger.Errorf("update task with error: %v", err)
} }
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoLuma})
continue continue
} }
@@ -150,7 +139,6 @@ func (s *Service) Run() {
if err != nil { if err != nil {
logger.Errorf("update task with error: %v", err) logger.Errorf("update task with error: %v", err)
} }
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed, Type: types.VideoKeLing})
continue continue
} }
@@ -170,25 +158,6 @@ func (s *Service) Run() {
}() }()
} }
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("Receive notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChLuma, message.Message)
}
}()
}
func (s *Service) DownloadFiles() { func (s *Service) DownloadFiles() {
go func() { go func() {
var items []model.VideoJob var items []model.VideoJob
@@ -232,7 +201,6 @@ func (s *Service) DownloadFiles() {
continue continue
} }
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[v.Id], UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished, Type: videoTask.Type})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
@@ -334,6 +302,12 @@ func (s *Service) SyncTaskProgress() {
logger.Errorf("更新数据库失败:%v", err) logger.Errorf("更新数据库失败:%v", err)
continue continue
} }
} else if task.TaskStatus == "failed" {
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": task.TaskStatusMsg,
})
} }
} }
@@ -672,7 +646,7 @@ func (s *Service) QueryKeLingTask(taskId string, channel string, action string)
return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode) return VideoCallbackData{}, fmt.Errorf("unexpected status code: %d", res.StatusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err) return VideoCallbackData{}, fmt.Errorf("failed to read response body: %w", err)
} }

View File

@@ -6,4 +6,5 @@ ALTER TABLE `chatgpt_sd_jobs` CHANGE `prompt` `prompt` TEXT CHARACTER SET utf8mb
ALTER TABLE `chatgpt_dall_jobs` CHANGE `prompt` `prompt` TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '提示词'; ALTER TABLE `chatgpt_dall_jobs` CHANGE `prompt` `prompt` TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '提示词';
ALTER TABLE `chatgpt_files` CHANGE `name` `name` VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '文件名'; ALTER TABLE `chatgpt_files` CHANGE `name` `name` VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '文件名';
ALTER TABLE `chatgpt_chat_models` CHANGE `name` `name` VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '模型名称'; ALTER TABLE `chatgpt_chat_models` CHANGE `name` `name` VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '模型名称';
ALTER TABLE `chatgpt_api_keys` CHANGE `value` `value` VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT 'API KEY value';

View File

@@ -231,7 +231,7 @@ import { Delete, InfoFilled, Picture } from "@element-plus/icons-vue";
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import { ElMessage, ElMessageBox } from "element-plus"; import { ElMessage, ElMessageBox } from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useSharedStore } from "@/store/sharedata"; import { useSharedStore } from "@/store/sharedata";
import TaskList from "@/components/TaskList.vue"; import TaskList from "@/components/TaskList.vue";
import BackTop from "@/components/BackTop.vue"; import BackTop from "@/components/BackTop.vue";
@@ -267,7 +267,6 @@ const styles = [
{ name: "自然", value: "natural" }, { name: "自然", value: "natural" },
]; ];
const params = ref({ const params = ref({
client_id: getClientId(),
quality: "standard", quality: "standard",
size: "1024x1024", size: "1024x1024",
style: "vivid", style: "vivid",
@@ -276,6 +275,7 @@ const params = ref({
const finishedJobs = ref([]); const finishedJobs = ref([]);
const runningJobs = ref([]); const runningJobs = ref([]);
const allowPulling = ref(true); // 是否允许轮询
const power = ref(0); const power = ref(0);
const dallPower = ref(0); // 画一张 SD 图片消耗算力 const dallPower = ref(0); // 画一张 SD 图片消耗算力
const clipboard = ref(null); const clipboard = ref(null);
@@ -301,20 +301,6 @@ onMounted(() => {
showMessageError("获取系统配置失败:" + e.message); showMessageError("获取系统配置失败:" + e.message);
}); });
store.addMessageHandler("dall", (data) => {
// 丢弃无关消息
if (data.channel !== "dall" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0;
isOver.value = false;
fetchFinishJobs();
}
nextTick(() => fetchRunningJobs());
});
// 获取模型列表 // 获取模型列表
httpGet("/api/dall/models") httpGet("/api/dall/models")
.then((res) => { .then((res) => {
@@ -339,10 +325,16 @@ const initData = () => {
power.value = user["power"]; power.value = user["power"];
userId.value = user.id; userId.value = user.id;
isLogin.value = true; isLogin.value = true;
page.value = 0; page.value = 0;
fetchRunningJobs(); fetchRunningJobs();
fetchFinishJobs(); fetchFinishJobs();
// 轮询运行中任务
setInterval(() => {
if (allowPulling.value) {
fetchRunningJobs();
}
}, 5000);
}) })
.catch(() => {}); .catch(() => {});
}; };
@@ -354,7 +346,17 @@ const fetchRunningJobs = () => {
// 获取运行中的任务 // 获取运行中的任务
httpGet(`/api/dall/jobs?finish=false`) httpGet(`/api/dall/jobs?finish=false`)
.then((res) => { .then((res) => {
runningJobs.value = res.data.items; // 如果任务有更新,则更新已完成任务列表
if (res.data.items && res.data.items.length !== runningJobs.value.length) {
page.value = 0;
fetchFinishJobs();
}
if (res.data.items.length > 0) {
runningJobs.value = res.data.items;
} else {
allowPulling.value = false;
runningJobs.value = [];
}
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("获取任务失败:" + e.message); ElMessage.error("获取任务失败:" + e.message);
@@ -410,7 +412,12 @@ const generate = () => {
.then(() => { .then(() => {
ElMessage.success("任务执行成功!"); ElMessage.success("任务执行成功!");
power.value -= dallPower.value; power.value -= dallPower.value;
fetchRunningJobs(); // 追加任务列表
runningJobs.value.push({
prompt: params.value.prompt,
progress: 0,
});
allowPulling.value = true;
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("任务执行失败:" + e.message); ElMessage.error("任务执行失败:" + e.message);

View File

@@ -656,14 +656,14 @@ import Compressor from "compressorjs";
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import { ElMessage, ElMessageBox, ElNotification } from "element-plus"; import { ElMessage, ElMessageBox, ElNotification } from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useRouter } from "vue-router"; import { useRouter } from "vue-router";
import { getSessionId } from "@/store/session"; import { getSessionId } from "@/store/session";
import { copyObj, removeArrayItem } from "@/utils/libs"; import { copyObj, removeArrayItem } from "@/utils/libs";
import { useSharedStore } from "@/store/sharedata"; import { useSharedStore } from "@/store/sharedata";
import TaskList from "@/components/TaskList.vue"; import TaskList from "@/components/TaskList.vue";
import BackTop from "@/components/BackTop.vue"; import BackTop from "@/components/BackTop.vue";
import {closeLoading, showLoading, showMessageError} from "@/utils/dialog"; import { closeLoading, showLoading, showMessageError } from "@/utils/dialog";
const listBoxHeight = ref(0); const listBoxHeight = ref(0);
const paramBoxHeight = ref(0); const paramBoxHeight = ref(0);
@@ -746,7 +746,6 @@ const options = [
const router = useRouter(); const router = useRouter();
const initParams = { const initParams = {
client_id: getClientId(),
task_type: "image", task_type: "image",
rate: rates[0].value, rate: rates[0].value,
model: models[0].value, model: models[0].value,
@@ -772,6 +771,8 @@ const activeName = ref("txt2img");
const runningJobs = ref([]); const runningJobs = ref([]);
const finishedJobs = ref([]); const finishedJobs = ref([]);
const taskPulling = ref(true); // 任务轮询
const downloadPulling = ref(false); // 图片下载轮询
const power = ref(0); const power = ref(0);
const userId = ref(0); const userId = ref(0);
@@ -788,20 +789,6 @@ onMounted(() => {
clipboard.value.on("error", () => { clipboard.value.on("error", () => {
ElMessage.error("复制失败!"); ElMessage.error("复制失败!");
}); });
store.addMessageHandler("mj", (data) => {
// 丢弃无关消息
if (data.channel !== "mj" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0;
isOver.value = false;
fetchFinishJobs();
}
nextTick(() => fetchRunningJobs());
});
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -817,8 +804,20 @@ const initData = () => {
userId.value = user.id; userId.value = user.id;
isLogin.value = true; isLogin.value = true;
page.value = 0; page.value = 0;
fetchRunningJobs();
fetchFinishJobs(); fetchFinishJobs();
setInterval(() => {
if (taskPulling.value) {
fetchRunningJobs();
}
}, 5000);
setInterval(() => {
if (downloadPulling.value) {
page.value = 0;
fetchFinishJobs();
}
}, 5000);
}) })
.catch(() => {}); .catch(() => {});
}; };
@@ -861,6 +860,14 @@ const fetchRunningJobs = () => {
} }
_jobs.push(jobs[i]); _jobs.push(jobs[i]);
} }
if (runningJobs.value.length !== _jobs.length) {
page.value = 0;
downloadPulling.value = true;
fetchFinishJobs();
}
if (_jobs.length === 0) {
taskPulling.value = false;
}
runningJobs.value = _jobs; runningJobs.value = _jobs;
}) })
.catch((e) => { .catch((e) => {
@@ -882,6 +889,7 @@ const fetchFinishJobs = () => {
httpGet(`/api/mj/jobs?finish=true&page=${page.value}&page_size=${pageSize.value}`) httpGet(`/api/mj/jobs?finish=true&page=${page.value}&page_size=${pageSize.value}`)
.then((res) => { .then((res) => {
const jobs = res.data.items; const jobs = res.data.items;
let hasDownload = false;
for (let i = 0; i < jobs.length; i++) { for (let i = 0; i < jobs.length; i++) {
if (jobs[i]["img_url"] !== "") { if (jobs[i]["img_url"] !== "") {
if (jobs[i].type === "upscale" || jobs[i].type === "swapFace") { if (jobs[i].type === "upscale" || jobs[i].type === "swapFace") {
@@ -890,16 +898,29 @@ const fetchFinishJobs = () => {
jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/480/q/75"; jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/480/q/75";
} }
} else { } else {
if (jobs[i].progress === 100) {
hasDownload = true;
}
jobs[i]["thumb_url"] = "/images/img-placeholder.jpg"; jobs[i]["thumb_url"] = "/images/img-placeholder.jpg";
} }
// 如果当前是第一页,则开启图片下载轮询
if (page.value === 1) {
downloadPulling.value = hasDownload;
}
if (jobs[i].type !== "upscale" && jobs[i].progress === 100) { if (jobs[i].type !== "upscale" && jobs[i].progress === 100) {
jobs[i]["can_opt"] = true; jobs[i]["can_opt"] = true;
} }
} }
if (jobs.length < pageSize.value) { if (jobs.length < pageSize.value) {
isOver.value = true; isOver.value = true;
} }
// 对比一下jobs和finishedJobs如果相同则不进行更新
if (JSON.stringify(jobs) === JSON.stringify(finishedJobs.value)) {
return;
}
if (page.value === 1) { if (page.value === 1) {
finishedJobs.value = jobs; finishedJobs.value = jobs;
} else { } else {
@@ -988,7 +1009,10 @@ const generate = () => {
.then(() => { .then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行..."); ElMessage.success("绘画任务推送成功,请耐心等待任务执行...");
power.value -= mjPower.value; power.value -= mjPower.value;
fetchRunningJobs(); taskPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("任务推送失败:" + e.message); ElMessage.error("任务推送失败:" + e.message);
@@ -1008,7 +1032,6 @@ const variation = (index, item) => {
const send = (url, index, item) => { const send = (url, index, item) => {
httpPost(url, { httpPost(url, {
index: index, index: index,
client_id: getClientId(),
channel_id: item.channel_id, channel_id: item.channel_id,
message_id: item.message_id, message_id: item.message_id,
message_hash: item.hash, message_hash: item.hash,
@@ -1018,7 +1041,10 @@ const send = (url, index, item) => {
.then(() => { .then(() => {
ElMessage.success("任务推送成功,请耐心等待任务执行..."); ElMessage.success("任务推送成功,请耐心等待任务执行...");
power.value -= mjActionPower.value; power.value -= mjActionPower.value;
fetchRunningJobs(); taskPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("任务推送失败:" + e.message); ElMessage.error("任务推送失败:" + e.message);

View File

@@ -325,7 +325,7 @@ import nodata from "@/assets/img/no-data.png";
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import { ElMessage, ElMessageBox } from "element-plus"; import { ElMessage, ElMessageBox } from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useRouter } from "vue-router"; import { useRouter } from "vue-router";
import { getSessionId } from "@/store/session"; import { getSessionId } from "@/store/session";
import { useSharedStore } from "@/store/sharedata"; import { useSharedStore } from "@/store/sharedata";
@@ -355,7 +355,6 @@ const samplers = ["Euler a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM++ 2M SD
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]; const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"];
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]; const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"];
const params = ref({ const params = ref({
client_id: getClientId(),
width: 1024, width: 1024,
height: 1024, height: 1024,
sampler: samplers[0], sampler: samplers[0],
@@ -374,6 +373,7 @@ const params = ref({
const runningJobs = ref([]); const runningJobs = ref([]);
const finishedJobs = ref([]); const finishedJobs = ref([]);
const allowPulling = ref(true); // 是否允许轮询
const router = useRouter(); const router = useRouter();
// 检查是否有画同款的参数 // 检查是否有画同款的参数
const _params = router.currentRoute.value.params["copyParams"]; const _params = router.currentRoute.value.params["copyParams"];
@@ -404,20 +404,6 @@ onMounted(() => {
.catch((e) => { .catch((e) => {
ElMessage.error("获取系统配置失败:" + e.message); ElMessage.error("获取系统配置失败:" + e.message);
}); });
store.addMessageHandler("sd", (data) => {
// 丢弃无关消息
if (data.channel !== "sd" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0;
isOver.value = false;
fetchFinishJobs();
}
nextTick(() => fetchRunningJobs());
});
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -434,6 +420,12 @@ const initData = () => {
page.value = 0; page.value = 0;
fetchRunningJobs(); fetchRunningJobs();
fetchFinishJobs(); fetchFinishJobs();
setInterval(() => {
if (allowPulling.value) {
fetchRunningJobs();
}
}, 5000);
}) })
.catch(() => {}); .catch(() => {});
}; };
@@ -446,6 +438,13 @@ const fetchRunningJobs = () => {
// 获取运行中的任务 // 获取运行中的任务
httpGet(`/api/sd/jobs?finish=0`) httpGet(`/api/sd/jobs?finish=0`)
.then((res) => { .then((res) => {
if (runningJobs.value.length !== res.data.items.length) {
page.value = 0;
fetchFinishJobs();
}
if (runningJobs.value.length === 0) {
allowPulling.value = false;
}
runningJobs.value = res.data.items; runningJobs.value = res.data.items;
}) })
.catch((e) => { .catch((e) => {
@@ -507,7 +506,10 @@ const generate = () => {
.then(() => { .then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行..."); ElMessage.success("绘画任务推送成功,请耐心等待任务执行...");
power.value -= sdPower.value; power.value -= sdPower.value;
fetchRunningJobs(); allowPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("任务推送失败:" + e.message); ElMessage.error("任务推送失败:" + e.message);

View File

@@ -21,18 +21,18 @@
<el-row :gutter="10"> <el-row :gutter="10">
<el-col :span="8" v-for="item in rates" :key="item.value"> <el-col :span="8" v-for="item in rates" :key="item.value">
<div <div
class="flex-col items-center" class="flex-col items-center"
:class=" :class="
item.value === params.aspect_ratio item.value === params.aspect_ratio
? 'grid-content active' ? 'grid-content active'
: 'grid-content' : 'grid-content'
" "
@click="changeRate(item)" @click="changeRate(item)"
> >
<el-image <el-image
class="icon proportion" class="icon proportion"
:src="item.img" :src="item.img"
fit="cover" fit="cover"
></el-image> ></el-image>
<div class="texts">{{ item.text }}</div> <div class="texts">{{ item.text }}</div>
</div> </div>
@@ -74,10 +74,10 @@
<div class="param-line"> <div class="param-line">
<el-form-item label="创意程度"> <el-form-item label="创意程度">
<el-slider <el-slider
v-model="params.cfg_scale" v-model="params.cfg_scale"
:min="0" :min="0"
:max="1" :max="1"
:step="0.1" :step="0.1"
/> />
</el-form-item> </el-form-item>
</div> </div>
@@ -96,8 +96,8 @@
<!-- 添加运镜类型选择 --> <!-- 添加运镜类型选择 -->
<el-form-item label="运镜类型"> <el-form-item label="运镜类型">
<el-select <el-select
v-model="params.camera_control.type" v-model="params.camera_control.type"
placeholder="请选择运镜类型" placeholder="请选择运镜类型"
> >
<el-option label="请选择" value="" /> <el-option label="请选择" value="" />
<el-option label="简单运镜" value="simple" /> <el-option label="简单运镜" value="simple" />
@@ -110,49 +110,49 @@
<!-- 仅在simple模式下显示详细配置 --> <!-- 仅在simple模式下显示详细配置 -->
<div <div
class="camera-control" class="camera-control"
v-if="params.camera_control.type === 'simple'" v-if="params.camera_control.type === 'simple'"
> >
<el-form-item label="水平移动"> <el-form-item label="水平移动">
<el-slider <el-slider
v-model="params.camera_control.config.horizontal" v-model="params.camera_control.config.horizontal"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
<el-form-item label="垂直移动"> <el-form-item label="垂直移动">
<el-slider <el-slider
v-model="params.camera_control.config.vertical" v-model="params.camera_control.config.vertical"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
<el-form-item label="左右旋转"> <el-form-item label="左右旋转">
<el-slider <el-slider
v-model="params.camera_control.config.pan" v-model="params.camera_control.config.pan"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
<el-form-item label="上下旋转"> <el-form-item label="上下旋转">
<el-slider <el-slider
v-model="params.camera_control.config.tilt" v-model="params.camera_control.config.tilt"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
<el-form-item label="横向翻转"> <el-form-item label="横向翻转">
<el-slider <el-slider
v-model="params.camera_control.config.roll" v-model="params.camera_control.config.roll"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
<el-form-item label="镜头缩放"> <el-form-item label="镜头缩放">
<el-slider <el-slider
v-model="params.camera_control.config.zoom" v-model="params.camera_control.config.zoom"
:min="-10" :min="-10"
:max="10" :max="10"
/> />
</el-form-item> </el-form-item>
</div> </div>
@@ -166,9 +166,9 @@
<!-- 任务类型选择 --> <!-- 任务类型选择 -->
<div class="param-line"> <div class="param-line">
<el-tabs <el-tabs
v-model="params.task_type" v-model="params.task_type"
@tab-change="tabChange" @tab-change="tabChange"
class="title-tabs" class="title-tabs"
> >
<el-tab-pane label="文生视频" name="text2video"> <el-tab-pane label="文生视频" name="text2video">
<div class="text">使用文字描述想要生成视频的内容</div> <div class="text">使用文字描述想要生成视频的内容</div>
@@ -186,18 +186,19 @@
<div class="generation-area"> <div class="generation-area">
<div v-if="params.task_type === 'text2video'" class="text2video"> <div v-if="params.task_type === 'text2video'" class="text2video">
<el-input <el-input
v-model="params.prompt" v-model="params.prompt"
type="textarea" type="textarea"
:autosize="{ minRows: 4, maxRows: 6 }" maxlength="500"
placeholder="请在此输入视频提示词,您也可以点击下面的提示词助手生成视频提示词" :autosize="{ minRows: 4, maxRows: 6 }"
placeholder="请在此输入视频提示词,您也可以点击下面的提示词助手生成视频提示词"
/> />
<el-row class="text-info"> <el-row class="text-info">
<el-button <el-button
class="generate-btn" class="generate-btn"
@click="generatePrompt" @click="generatePrompt"
:loading="isGenerating" :loading="isGenerating"
size="small" size="small"
color="#5865f2" color="#5865f2"
> >
<i class="iconfont icon-chuangzuo"></i> <i class="iconfont icon-chuangzuo"></i>
生成专业视频提示词 生成专业视频提示词
@@ -210,16 +211,16 @@
<div class="upload-box img-uploader"> <div class="upload-box img-uploader">
<h4>起始帧</h4> <h4>起始帧</h4>
<el-upload <el-upload
class="uploader img-uploader" class="uploader img-uploader"
:auto-upload="true" :auto-upload="true"
:show-file-list="false" :show-file-list="false"
:http-request="uploadStartImage" :http-request="uploadStartImage"
accept=".jpg,.png,.jpeg" accept=".jpg,.png,.jpeg"
> >
<img <img
v-if="params.image" v-if="params.image"
:src="params.image" :src="params.image"
class="preview" class="preview"
/> />
<el-icon v-else class="upload-icon"><Plus /></el-icon> <el-icon v-else class="upload-icon"><Plus /></el-icon>
</el-upload> </el-upload>
@@ -227,16 +228,16 @@
<div class="upload-box img-uploader"> <div class="upload-box img-uploader">
<h4>结束帧</h4> <h4>结束帧</h4>
<el-upload <el-upload
class="uploader" class="uploader"
:auto-upload="true" :auto-upload="true"
:show-file-list="false" :show-file-list="false"
:http-request="uploadEndImage" :http-request="uploadEndImage"
accept=".jpg,.png,.jpeg" accept=".jpg,.png,.jpeg"
> >
<img <img
v-if="params.image_tail" v-if="params.image_tail"
:src="params.image_tail" :src="params.image_tail"
class="preview" class="preview"
/> />
<el-icon v-else class="upload-icon"><Plus /></el-icon> <el-icon v-else class="upload-icon"><Plus /></el-icon>
</el-upload> </el-upload>
@@ -247,8 +248,8 @@
<div class="flex-row justify-start items-center"> <div class="flex-row justify-start items-center">
<span>提示词</span> <span>提示词</span>
<el-tooltip <el-tooltip
content="输入你想要的内容,用逗号分割" content="输入你想要的内容,用逗号分割"
placement="right" placement="right"
> >
<el-icon> <el-icon>
<InfoFilled /> <InfoFilled />
@@ -259,10 +260,10 @@
</div> </div>
<div class="param-line pt"> <div class="param-line pt">
<el-input <el-input
v-model="params.prompt" v-model="params.prompt"
type="textarea" type="textarea"
:autosize="{ minRows: 4, maxRows: 6 }" :autosize="{ minRows: 4, maxRows: 6 }"
placeholder="描述视频画面细节" placeholder="描述视频画面细节"
/> />
</div> </div>
</div> </div>
@@ -273,8 +274,8 @@
<div class="flex-row justify-start items-center"> <div class="flex-row justify-start items-center">
<span>不希望出现的内容可选</span> <span>不希望出现的内容可选</span>
<el-tooltip <el-tooltip
content="不想出现在图片上的元素(例如:树,建筑)" content="不想出现在图片上的元素(例如:树,建筑)"
placement="right" placement="right"
> >
<el-icon> <el-icon>
<InfoFilled /> <InfoFilled />
@@ -285,21 +286,21 @@
</div> </div>
<div class="param-line pt"> <div class="param-line pt">
<el-input <el-input
v-model="params.negative_prompt" v-model="params.negative_prompt"
type="textarea" type="textarea"
:autosize="{ minRows: 4, maxRows: 6 }" :autosize="{ minRows: 4, maxRows: 6 }"
placeholder="请在此输入你不希望出现在视频上的内容" placeholder="请在此输入你不希望出现在视频上的内容"
/> />
</div> </div>
<!-- 算力显示 --> <!-- 算力显示 -->
<el-row class="text-info"> <el-row class="text-info">
<el-text type="primary" <el-text type="primary"
>每次生成视频消耗 >每次生成视频消耗
<el-text type="warning">{{ powerCost }}算力;</el-text> </el-text <el-text type="warning">{{ powerCost }}算力;</el-text> </el-text
>&nbsp;&nbsp; >&nbsp;&nbsp;
<el-text type="primary" <el-text type="primary"
>当前可用算力<el-text type="warning">{{ >当前可用算力<el-text type="warning">{{
availablePower availablePower
}}</el-text></el-text }}</el-text></el-text
> >
@@ -308,7 +309,7 @@
<!-- 生成按钮 --> <!-- 生成按钮 -->
<div class="submit-btn"> <div class="submit-btn">
<el-button type="primary" :dark="false" @click="generate" round <el-button type="primary" :dark="false" @click="generate" round
>立即生成</el-button >立即生成</el-button
> >
</div> </div>
</div> </div>
@@ -332,153 +333,85 @@
<h2 class="record-title pt">创作记录</h2> <h2 class="record-title pt">创作记录</h2>
<!-- 已完成的任务 --> <!-- 已完成的任务 -->
<v3-waterfall <v3-waterfall
:virtual-time="200" :key="waterfallKey"
:distance-to-scroll="150" :list="finishedTasks"
:key="waterfallKey" @scrollReachBottom="fetchTasks"
:list="finishedTasks" :gap="20"
@scrollReachBottom="fetchTasks" :bottomGap="20"
:gap="8" :colWidth="300"
:bottomGap="8" :distanceToScroll="100"
:colWidth="300" :isLoading="loading"
:distanceToScroll="100" :isOver="isOver"
:isLoading="loading" class="task-waterfall"
:isOver="isOver"
class="task-waterfall"
> >
<template #default="slotProp"> <template #default="slotProp">
<!-- 视频成功渲染部分 -->
<div <div
class="job-item-box" class="job-item-box"
:class="{ :class="{
processing: slotProp.item.progress < 100, processing: slotProp.item.progress < 100,
error: slotProp.item.progress === 101 error: slotProp.item.progress === 101
}" }"
> >
<video <video
v-if=" v-if="slotProp.item.progress === 100"
slotProp.item.progress >= 100 && slotProp.item.video_url class="preview"
" :src="slotProp.item.video_url"
class="preview" @click="previewVideo(slotProp.item)"
:src="slotProp.item.video_url" controls
@click="previewVideo(slotProp.item)"
controls
:style="{
width: '100%',
height: `${slotProp.item.height || 400}px`
}"
></video> ></video>
<!-- 失败/无图状态 --> <div v-else class="status-overlay">
<div
v-else
class="error-container"
:style="{
width: '100%',
height: `${slotProp.item.height || 300}px`,
objectFit: 'cover'
}"
>
<div <div
v-if=" v-if="slotProp.item.progress === 101"
slotProp.item.progress >= 100 && class="error-status"
!slotProp.item.video_url
"
class="error-status"
> >
<img :src="failed" /> <el-icon><CloseBold /></el-icon>
生成失败 任务失败
</div> </div>
<div v-else class="processing-status"> <div v-else class="processing-status">
<el-progress <el-progress
:percentage="slotProp.item.progress" :percentage="slotProp.item.progress"
:stroke-width="12" :stroke-width="12"
status="success" status="success"
/> />
</div> </div>
</div> </div>
<div class="tools-box">
<div class="tools"> <div class="tools">
<el-button <el-button
type="primary" v-if="slotProp.item.progress === 100"
v-if="
slotProp.item.progress >= 100 &&
slotProp.item.video_url
"
@click="downloadVideo(slotProp.item)" @click="downloadVideo(slotProp.item)"
> >
<el-icon><Download /></el-icon> <el-icon><Download /></el-icon>
</el-button> </el-button>
<el-button type="danger" @click="deleteTask(slotProp.item)">
<div <el-icon><Delete /></el-icon>
class="show-prompt" </el-button>
v-if=" <div class="show-prompt">
slotProp.item.progress >= 100 && <el-popover
!slotProp.item.video_url &&
slotProp.item.err_msg
"
>
<el-popover
placement="left" placement="left"
title="提示词"
:width="240" :width="240"
trigger="hover" trigger="hover"
> >
<template #reference> <template #reference>
<el-icon class="chromefilled error-txt" <el-icon class="chromefilled">
><WarnTriangleFilled <ChromeFilled />
/></el-icon> </el-icon>
</template> </template>
<template #default> <template #default>
<div class="top-tips"> <div class="mj-list-item-prompt">
<span>错误详细信息</span <span>{{ slotProp.item.prompt }}</span>
><el-icon <el-icon
class="copy-prompt-kl" class="copy-prompt-mj"
:data-clipboard-text="slotProp.item.prompt" :data-clipboard-text="slotProp.item.prompt"
> >
<DocumentCopy /> <DocumentCopy />
</el-icon>
</div>
<div class="mj-list-item-prompt">
<span>{{ slotProp.item.prompt }}</span>
</div>
</template>
</el-popover>
</div>
<el-button
type="danger"
@click="deleteTask(slotProp.item)"
>
<el-icon><Delete /></el-icon>
</el-button>
<div class="show-prompt">
<el-popover
placement="left"
:width="240"
trigger="hover"
>
<template #reference>
<el-icon class="chromefilled">
<ChromeFilled />
</el-icon> </el-icon>
</template> </div>
</template>
<template #default> </el-popover>
<div class="top-tips">
<span>提示词</span
><el-icon
class="copy-prompt-kl"
:data-clipboard-text="slotProp.item.prompt"
>
<DocumentCopy />
</el-icon>
</div>
<div class="mj-list-item-prompt">
<span>{{ slotProp.item.prompt }}</span>
</div>
</template>
</el-popover>
</div>
</div> </div>
</div> </div>
</div> </div>
@@ -498,19 +431,18 @@
<!-- 视频预览对话框 --> <!-- 视频预览对话框 -->
<el-dialog v-model="previewVisible" title="视频预览" width="80%"> <el-dialog v-model="previewVisible" title="视频预览" width="80%">
<video <video
v-if="currentVideo" v-if="currentVideo"
:src="currentVideo" :src="currentVideo"
controls controls
style="width: 100%" style="width: 100%"
></video> ></video>
</el-dialog> </el-dialog>
</div> </div>
</template> </template>
<script setup> <script setup>
import failed from "@/assets/img/failed.png";
import TaskList from "@/components/TaskList.vue"; import TaskList from "@/components/TaskList.vue";
import { ref, reactive, onMounted, onUnmounted, watch, computed } from "vue"; import { ref, reactive, onMounted, onUnmounted, watch } from "vue";
import { import {
Plus, Plus,
Delete, Delete,
@@ -518,12 +450,11 @@ import {
ChromeFilled, ChromeFilled,
DocumentCopy, DocumentCopy,
Download, Download,
WarnTriangleFilled CloseBold
} from "@element-plus/icons-vue"; } from "@element-plus/icons-vue";
import { httpGet, httpPost, httpDownload } from "@/utils/http"; import { httpGet, httpPost, httpDownload } from "@/utils/http";
import { ElMessage, ElMessageBox } from "element-plus"; import { ElMessage, ElMessageBox } from "element-plus";
import { getClientId, checkSession } from "@/store/cache"; import { checkSession } from "@/store/cache";
import Clipboard from "clipboard";
import { import {
closeLoading, closeLoading,
@@ -535,7 +466,6 @@ import { replaceImg } from "@/utils/libs";
// 参数设置 // 参数设置
const params = reactive({ const params = reactive({
client_id: getClientId(),
task_type: "text2video", task_type: "text2video",
model: "default", model: "default",
prompt: "", prompt: "",
@@ -559,7 +489,7 @@ const params = reactive({
image_tail: "" image_tail: ""
}); });
const rates = [ const rates = [
{ css: "square", value: "1:1", text: "1:1", img: "/images/mj/rate_1_1.png" }, {css: "square", value: "1:1", text: "1:1", img: "/images/mj/rate_1_1.png"},
{ {
css: "size16-9", css: "size16-9",
@@ -594,6 +524,7 @@ const currentPage = ref(1);
const previewVisible = ref(false); const previewVisible = ref(false);
const currentVideo = ref(""); const currentVideo = ref("");
const isOver = ref(false); const isOver = ref(false);
const pullTask = ref(true);
// 方法定义 // 方法定义
@@ -634,7 +565,7 @@ const generatePrompt = async () => {
} }
isGenerating.value = true; isGenerating.value = true;
try { try {
const res = await httpPost("/api/prompt/video", { prompt: params.prompt }); const res = await httpPost("/api/prompt/video", {prompt: params.prompt});
params.prompt = res.data; params.prompt = res.data;
} catch (e) { } catch (e) {
showMessageError("生成失败: " + e.message); showMessageError("生成失败: " + e.message);
@@ -647,6 +578,10 @@ const generate = async () => {
if (!params.prompt?.trim()) { if (!params.prompt?.trim()) {
return ElMessage.error("请输入视频描述"); return ElMessage.error("请输入视频描述");
} }
// 提示词长度不能超过 500
if (params.prompt.length > 500) {
return ElMessage.error("视频描述不能超过 500 个字符");
}
// if (params.task_type === "image2video" && !params.image) { // if (params.task_type === "image2video" && !params.image) {
// return ElMessage.error("请上传起始帧图片"); // return ElMessage.error("请上传起始帧图片");
// } // }
@@ -659,7 +594,9 @@ const generate = async () => {
await httpPost("/api/video/keling/create", params); await httpPost("/api/video/keling/create", params);
showMessageOK("任务创建成功"); showMessageOK("任务创建成功");
// 立即获取最新数据 // 立即获取最新数据
fetchTasks(); await fetchTasks();
// 开启任务轮询
pullTask.value = true;
} catch (e) { } catch (e) {
showMessageError("创建失败: " + e.message); showMessageError("创建失败: " + e.message);
} finally { } finally {
@@ -686,26 +623,25 @@ const fetchTasks = async () => {
// 精确任务过滤逻辑 // 精确任务过滤逻辑
const data = res.data || {}; const data = res.data || {};
const newRunning = data.items.filter( const newRunning = data.items.filter(
(task) => task.progress < 100 && task.progress !== 101 (task) => task.progress < 100 && task.progress !== 101
); );
runningTasks.value = [...runningTasks.value, ...newRunning]; runningTasks.value = [...runningTasks.value, ...newRunning];
// 如果运行中的任务为零,则停止轮询
if (newRunning.length === 0) {
pullTask.value = false;
}
const newfinished = data.items.filter((task) => task.progress >= 100); const newFinished = data.items.filter((task) => task.progress >= 100);
const finishedList = [...finishedTasks.value, ...newfinished]; finishedTasks.value = [...finishedTasks.value, ...newFinished];
finishedTasks.value = finishedList.map((item) => ({
...item,
height: 300 * (Math.random() * 0.4 + 0.6) // 生成300~420px随机高度
}));
console.log("finishedTasks: " + finishedList);
// // 强制刷新瀑布流 // // 强制刷新瀑布流
waterfallKey.value = Date.now(); waterfallKey.value = Date.now();
total.value = data.total; total.value = data.total;
const shouldLoadNextPage = const shouldLoadNextPage =
runningTasks.value.length > 0 || runningTasks.value.length > 0 ||
(runningTasks.value.length === 0 && (runningTasks.value.length === 0 &&
finishedTasks.value.length < total.value); finishedTasks.value.length < total.value);
if (shouldLoadNextPage) { if (shouldLoadNextPage) {
currentPage.value++; currentPage.value++;
@@ -758,9 +694,9 @@ const downloadVideo = async (task) => {
const deleteTask = async (task) => { const deleteTask = async (task) => {
try { try {
await ElMessageBox.confirm("确定要删除该任务吗?"); await ElMessageBox.confirm("确定要删除该任务吗?");
await httpGet("/api/video/remove", { id: task.id }); await httpGet("/api/video/remove", {id: task.id});
showMessageOK("删除成功"); showMessageOK("删除成功");
fetchTasks(); await fetchTasks();
} catch (e) { } catch (e) {
if (e !== "cancel") { if (e !== "cancel") {
showMessageError("删除失败: " + e.message); showMessageError("删除失败: " + e.message);
@@ -768,31 +704,24 @@ const deleteTask = async (task) => {
} }
}; };
fetchTasks();
const clipboard = ref(null);
// 生命周期钩子 // 生命周期钩子
onMounted(async () => { onMounted(async () => {
checkSession() checkSession()
.then(async () => { .then(async () => {
isLogin.value = true; isLogin.value = true;
console.log("mounted-isLogin-可以继续", isLogin.value); console.log("mounted-isLogin-可以继续", isLogin.value);
await fetchTasks();
})
.catch(() => {});
// fetchTasks(); await fetchTasks();
clipboard.value = new Clipboard(".copy-prompt-kl"); setInterval(() => {
clipboard.value.on("success", () => { if (pullTask.value) {
ElMessage.success("复制成功!"); fetchTasks();
}); }
}, 5000)
clipboard.value.on("error", () => { })
ElMessage.error("复制失败!"); .catch(() => {
}); });
}); });
onUnmounted(() => { onUnmounted(() => {
clipboard.value.destroy();
}); });
// 监听任务状态变化 // 监听任务状态变化
watch([runningTasks, finishedTasks], () => { watch([runningTasks, finishedTasks], () => {
@@ -805,25 +734,12 @@ watch([runningTasks, finishedTasks], () => {
<style lang="stylus" scoped> <style lang="stylus" scoped>
@import "@/assets/css/image-keling.styl" @import "@/assets/css/image-keling.styl"
@import "@/assets/css/custom-scroll.styl" @import "@/assets/css/custom-scroll.styl"
.copy-prompt-kl{ .mj-list-item-prompt {
cursor pointer
}
.top-tips{
height: 30px
font-size: 18px
line-height: 30px
display: flex
align-items: center;
span{
margin-right: 10px
color:#000
}
}
.mj-list-item-prompt{
max-height: 600px; max-height: 600px;
overflow: auto; overflow: auto;
} }
:deep(.running-job-box .image-slot){
:deep(.running-job-box .image-slot) {
display: flex display: flex
align-items: center align-items: center
flex-direction: column; flex-direction: column;
@@ -836,98 +752,88 @@ watch([runningTasks, finishedTasks], () => {
text-align: center text-align: center
width: 200px; width: 200px;
height: 200px; height: 200px;
.iconfont{
font-size: 45px; .iconfont {
font-size: 45px
} }
span{
span {
font-size: 15px font-size: 15px
} }
} }
.record-title .record-title
padding:1rem 0 padding: 1rem 0
.type-btn-group .type-btn-group
margin-bottom: 20px margin-bottom: 20px
.task-waterfall .task-waterfall
margin: 0 -10px margin: 0 -10px
transition: opacity 0.3s ease transition: opacity 0.3s ease
.job-item-box .job-item-box
position: relative position: relative
background: #f5f5f5; transition: transform 0.3s ease
transition: height 0.3s ease;
overflow: hidden overflow: hidden
// margin: 10px margin: 10px
// border: 1px solid #666; border: 1px solid #666;
// padding: 6px; padding: 6px;
border-radius: 6px; border-radius: 6px;
break-inside: avoid break-inside: avoid
video video
min-height: 200px; min-height: 200px;
width: 100%;
object-fit: cover;
.chromefilled .chromefilled
font-size: 24px; font-size: 24px;
color: #fff;
&.error-txt{
color: #ffff54;
cursor:pointer;
}
.show-prompt .show-prompt
display: flex; display: flex;
align-items: center; align-items: center;
&:hover
// transform: translateY(-3px)
.tools-box{
display:block
background:rgba(0, 0, 0, 0.3)
width : 100%;
}
.error-container &:hover
position: relative transform: translateY(-3px)
background: var(--bg-deep-color)
.status-overlay
position: absolute
top: 0
left: 0
right: 0
bottom: 0
background: rgba(0, 0, 0, 0.7)
display: flex display: flex
align-items: center align-items: center
justify-content: center justify-content: center
img{
width: 66%;
height: 66%;
object-fit: cover;
margin: 0 auto;
}
.error-status .error-status
color: #c2c6cc color: #ff4d4f
text-align: center text-align: center
font-size: 24px
.el-icon
font-size: 24px
display: block
margin-bottom: 8px
.processing-status .processing-status
width: 80% width: 80%
.el-progress .el-progress
margin: 0 auto margin: 0 auto
.tools-box{
display:none
position:absolute;
top: 0;
right: 0;
}
.tools .tools
align-items: center; align-items: center;
justify-content: flex-end; justify-content: flex-end;
display: flex display: flex
gap: 5px gap: 8px
margin: 5px 5px 5px 0; margin: 5px 0 0
.el-button + .el-button
.el-button+.el-button
margin-left: 0px; margin-left: 0px;
.el-button .el-button
padding: 3px padding: 6px
border-radius: 50% border-radius: 50%
</style> </style>

View File

@@ -124,14 +124,13 @@ import nodata from "@/assets/img/no-data.png";
import { onMounted, onUnmounted, reactive, ref } from "vue"; import { onMounted, onUnmounted, reactive, ref } from "vue";
import { CircleCloseFilled } from "@element-plus/icons-vue"; import { CircleCloseFilled } from "@element-plus/icons-vue";
import { httpDownload, httpPost, httpGet } from "@/utils/http"; import { httpDownload, httpPost, httpGet } from "@/utils/http";
import { checkSession, getClientId } from "@/store/cache"; import { checkSession } from "@/store/cache";
import { closeLoading, showLoading, showMessageError, showMessageOK } from "@/utils/dialog"; import { closeLoading, showLoading, showMessageError, showMessageOK } from "@/utils/dialog";
import { replaceImg } from "@/utils/libs"; import { replaceImg } from "@/utils/libs";
import { ElMessage, ElMessageBox } from "element-plus"; import { ElMessage, ElMessageBox } from "element-plus";
import BlackSwitch from "@/components/ui/BlackSwitch.vue"; import BlackSwitch from "@/components/ui/BlackSwitch.vue";
import Generating from "@/components/ui/Generating.vue"; import Generating from "@/components/ui/Generating.vue";
import BlackDialog from "@/components/ui/BlackDialog.vue"; import BlackDialog from "@/components/ui/BlackDialog.vue";
import { useSharedStore } from "@/store/sharedata";
const showDialog = ref(false); const showDialog = ref(false);
const currentVideoUrl = ref(""); const currentVideoUrl = ref("");
@@ -139,7 +138,6 @@ const row = ref(1);
const images = ref([]); const images = ref([]);
const formData = reactive({ const formData = reactive({
client_id: getClientId(),
prompt: "", prompt: "",
expand_prompt: false, expand_prompt: false,
loop: false, loop: false,
@@ -147,32 +145,28 @@ const formData = reactive({
end_frame_img: "", end_frame_img: "",
}); });
const store = useSharedStore(); const loading = ref(false);
const list = ref([]);
const noData = ref(true);
const page = ref(1);
const pageSize = ref(10);
const total = ref(0);
const taskPulling = ref(true);
onMounted(() => { onMounted(() => {
checkSession().then(() => { checkSession().then(() => {
fetchData(1); fetchData(1);
}).catch(() => {}); setInterval(() => {
if (taskPulling.value) {
store.addMessageHandler("luma", (data) => { fetchData(1);
// 丢弃无关消息 }
if (data.channel !== "luma" || data.clientId !== getClientId()) { }, 5000);
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
fetchData(1);
}
}); });
}); });
onUnmounted(() => {
store.removeMessageHandler("luma");
});
const download = (item) => { const download = (item) => {
const url = replaceImg(item.video_url); const url = replaceImg(item.video_url);
const downloadURL = `${process.env.VUE_APP_API_HOST}/api/download?url=${url}`; const downloadURL = `${process.env.VUE_APP_API_HOST}/api/download?url=${url}`;
// parse filename
const urlObj = new URL(url); const urlObj = new URL(url);
const fileName = urlObj.pathname.split("/").pop(); const fileName = urlObj.pathname.split("/").pop();
item.downloading = true; item.downloading = true;
@@ -231,17 +225,16 @@ const publishJob = (item) => {
const upload = (file) => { const upload = (file) => {
const formData = new FormData(); const formData = new FormData();
formData.append("file", file.file, file.name); formData.append("file", file.file, file.name);
showLoading("正在上传文件...") showLoading("正在上传文件...");
// 执行上传操作
httpPost("/api/upload", formData) httpPost("/api/upload", formData)
.then((res) => { .then((res) => {
images.value.push(res.data.url); images.value.push(res.data.url);
ElMessage.success({ message: "上传成功", duration: 500 }); ElMessage.success({ message: "上传成功", duration: 500 });
closeLoading() closeLoading();
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("图片上传失败:" + e.message); ElMessage.error("图片上传失败:" + e.message);
closeLoading() closeLoading();
}); });
}; };
@@ -252,12 +245,7 @@ const remove = (img) => {
const switchReverse = () => { const switchReverse = () => {
images.value = images.value.reverse(); images.value = images.value.reverse();
}; };
const loading = ref(false);
const list = ref([]);
const noData = ref(true);
const page = ref(1);
const pageSize = ref(10);
const total = ref(0);
const fetchData = (_page) => { const fetchData = (_page) => {
if (_page) { if (_page) {
page.value = _page; page.value = _page;
@@ -269,8 +257,19 @@ const fetchData = (_page) => {
}) })
.then((res) => { .then((res) => {
total.value = res.data.total; total.value = res.data.total;
let needPull = false;
const items = [];
for (let v of res.data.items) {
if (v.progress === 0 || v.progress === 102) {
needPull = true;
}
items.push(v);
}
loading.value = false; loading.value = false;
list.value = res.data.items; taskPulling.value = needPull;
if (JSON.stringify(list.value) !== JSON.stringify(items)) {
list.value = items;
}
noData.value = list.value.length === 0; noData.value = list.value.length === 0;
}) })
.catch(() => { .catch(() => {
@@ -279,7 +278,6 @@ const fetchData = (_page) => {
}); });
}; };
// 创建视频
const create = () => { const create = () => {
const len = images.value.length; const len = images.value.length;
if (len) { if (len) {
@@ -292,6 +290,7 @@ const create = () => {
httpPost("/api/video/luma/create", formData) httpPost("/api/video/luma/create", formData)
.then(() => { .then(() => {
fetchData(1); fetchData(1);
taskPulling.value = true;
showMessageOK("创建任务成功"); showMessageOK("创建任务成功");
}) })
.catch((e) => { .catch((e) => {

View File

@@ -278,8 +278,8 @@ import BlackInput from "@/components/ui/BlackInput.vue";
import MusicPlayer from "@/components/MusicPlayer.vue"; import MusicPlayer from "@/components/MusicPlayer.vue";
import { compact } from "lodash"; import { compact } from "lodash";
import { httpDownload, httpGet, httpPost } from "@/utils/http"; import { httpDownload, httpGet, httpPost } from "@/utils/http";
import {closeLoading, showLoading, showMessageError, showMessageOK} from "@/utils/dialog"; import { closeLoading, showLoading, showMessageError, showMessageOK } from "@/utils/dialog";
import { checkSession, getClientId } from "@/store/cache"; import { checkSession } from "@/store/cache";
import { ElMessage, ElMessageBox } from "element-plus"; import { ElMessage, ElMessageBox } from "element-plus";
import { formatTime, replaceImg } from "@/utils/libs"; import { formatTime, replaceImg } from "@/utils/libs";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
@@ -313,7 +313,6 @@ const tags = ref([
{ label: "嘻哈", value: "hip hop" }, { label: "嘻哈", value: "hip hop" },
]); ]);
const data = ref({ const data = ref({
client_id: getClientId(),
model: "chirp-v3-0", model: "chirp-v3-0",
tags: "", tags: "",
lyrics: "", lyrics: "",
@@ -330,6 +329,7 @@ const playList = ref([]);
const playerRef = ref(null); const playerRef = ref(null);
const showPlayer = ref(false); const showPlayer = ref(false);
const list = ref([]); const list = ref([]);
const taskPulling = ref(true);
const btnText = ref("开始创作"); const btnText = ref("开始创作");
const refSong = ref(null); const refSong = ref(null);
const showDialog = ref(false); const showDialog = ref(false);
@@ -350,19 +350,13 @@ onMounted(() => {
checkSession() checkSession()
.then(() => { .then(() => {
fetchData(1); fetchData(1);
setInterval(() => {
if (taskPulling.value) {
fetchData(1);
}
}, 5000);
}) })
.catch(() => {}); .catch(() => {});
store.addMessageHandler("suno", (data) => {
// 丢弃无关消息
if (data.channel !== "suno" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
fetchData(1);
}
});
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -381,15 +375,23 @@ const fetchData = (_page) => {
httpGet("/api/suno/list", { page: page.value, page_size: pageSize.value }) httpGet("/api/suno/list", { page: page.value, page_size: pageSize.value })
.then((res) => { .then((res) => {
total.value = res.data.total; total.value = res.data.total;
let needPull = false;
const items = []; const items = [];
for (let v of res.data.items) { for (let v of res.data.items) {
if (v.progress === 100) { if (v.progress === 100) {
v.major_model_version = v["raw_data"]["major_model_version"]; v.major_model_version = v["raw_data"]["major_model_version"];
} }
if (v.progress === 0 || v.progress === 102) {
needPull = true;
}
items.push(v); items.push(v);
} }
loading.value = false; loading.value = false;
list.value = items; taskPulling.value = needPull;
// 如果任务有变化,则刷新任务列表
if (JSON.stringify(list.value) !== JSON.stringify(items)) {
list.value = items;
}
noData.value = list.value.length === 0; noData.value = list.value.length === 0;
}) })
.catch((e) => { .catch((e) => {
@@ -425,6 +427,7 @@ const create = () => {
httpPost("/api/suno/create", data.value) httpPost("/api/suno/create", data.value)
.then(() => { .then(() => {
fetchData(1); fetchData(1);
taskPulling.value = true;
showMessageOK("创建任务成功"); showMessageOK("创建任务成功");
}) })
.catch((e) => { .catch((e) => {
@@ -437,6 +440,7 @@ const merge = (item) => {
httpPost("/api/suno/create", { song_id: item.song_id, type: 3 }) httpPost("/api/suno/create", { song_id: item.song_id, type: 3 })
.then(() => { .then(() => {
fetchData(1); fetchData(1);
taskPulling.value = true;
showMessageOK("创建任务成功"); showMessageOK("创建任务成功");
}) })
.catch((e) => { .catch((e) => {
@@ -606,11 +610,11 @@ const uploadCover = (file) => {
.then((res) => { .then((res) => {
editData.value.cover = res.data.url; editData.value.cover = res.data.url;
ElMessage.success({ message: "上传成功", duration: 500 }); ElMessage.success({ message: "上传成功", duration: 500 });
closeLoading() closeLoading();
}) })
.catch((e) => { .catch((e) => {
ElMessage.error("图片上传失败:" + e.message); ElMessage.error("图片上传失败:" + e.message);
closeLoading() closeLoading();
}); });
}, },
error(err) { error(err) {

View File

@@ -133,7 +133,7 @@ import { onMounted, onUnmounted, ref } from "vue";
import { Delete } from "@element-plus/icons-vue"; import { Delete } from "@element-plus/icons-vue";
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useRouter } from "vue-router"; import { useRouter } from "vue-router";
import { getSessionId } from "@/store/session"; import { getSessionId } from "@/store/session";
import { showConfirmDialog, showDialog, showFailToast, showImagePreview, showNotify, showSuccessToast, showToast } from "vant"; import { showConfirmDialog, showDialog, showFailToast, showImagePreview, showNotify, showSuccessToast, showToast } from "vant";
@@ -174,7 +174,6 @@ const styles = [
{ text: "自然", value: "natural" }, { text: "自然", value: "natural" },
]; ];
const params = ref({ const params = ref({
client_id: getClientId(),
quality: qualities[0].value, quality: qualities[0].value,
size: sizes[0].value, size: sizes[0].value,
style: styles[0].value, style: styles[0].value,
@@ -191,6 +190,7 @@ const showModelPicker = ref(false);
const runningJobs = ref([]); const runningJobs = ref([]);
const finishedJobs = ref([]); const finishedJobs = ref([]);
const allowPulling = ref(true); // 是否允许轮询
const router = useRouter(); const router = useRouter();
const power = ref(0); const power = ref(0);
const dallPower = ref(0); // 画一张 DALL 图片消耗算力 const dallPower = ref(0); // 画一张 DALL 图片消耗算力
@@ -220,17 +220,6 @@ onMounted(() => {
showNotify({ type: "danger", message: "获取系统配置失败:" + e.message }); showNotify({ type: "danger", message: "获取系统配置失败:" + e.message });
}); });
store.addMessageHandler("dall", (data) => {
if (data.channel !== "dall" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 1;
fetchFinishJobs(1);
}
fetchRunningJobs();
});
// 获取模型列表 // 获取模型列表
httpGet("/api/dall/models") httpGet("/api/dall/models")
.then((res) => { .then((res) => {
@@ -257,6 +246,12 @@ const initData = () => {
isLogin.value = true; isLogin.value = true;
fetchRunningJobs(); fetchRunningJobs();
fetchFinishJobs(1); fetchFinishJobs(1);
setInterval(() => {
if (allowPulling.value) {
fetchRunningJobs();
}
}, 5000);
}) })
.catch(() => { .catch(() => {
loading.value = false; loading.value = false;
@@ -267,6 +262,12 @@ const fetchRunningJobs = () => {
// 获取运行中的任务 // 获取运行中的任务
httpGet(`/api/dall/jobs?finish=0`) httpGet(`/api/dall/jobs?finish=0`)
.then((res) => { .then((res) => {
if (runningJobs.value.length !== res.data.items.length) {
fetchFinishJobs(1);
}
if (res.data.items.length === 0) {
allowPulling.value = false;
}
runningJobs.value = res.data.items; runningJobs.value = res.data.items;
}) })
.catch((e) => { .catch((e) => {
@@ -333,7 +334,10 @@ const generate = () => {
.then(() => { .then(() => {
showSuccessToast("绘画任务推送成功,请耐心等待任务执行..."); showSuccessToast("绘画任务推送成功,请耐心等待任务执行...");
power.value -= dallPower.value; power.value -= dallPower.value;
fetchRunningJobs(); allowPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
showFailToast("任务推送失败:" + e.message); showFailToast("任务推送失败:" + e.message);

View File

@@ -255,7 +255,7 @@ import { showConfirmDialog, showFailToast, showImagePreview, showNotify, showSuc
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import Compressor from "compressorjs"; import Compressor from "compressorjs";
import { getSessionId } from "@/store/session"; import { getSessionId } from "@/store/session";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useRouter } from "vue-router"; import { useRouter } from "vue-router";
import { Delete } from "@element-plus/icons-vue"; import { Delete } from "@element-plus/icons-vue";
import { showLoginDialog } from "@/utils/libs"; import { showLoginDialog } from "@/utils/libs";
@@ -282,7 +282,6 @@ const models = [
]; ];
const imgList = ref([]); const imgList = ref([]);
const params = ref({ const params = ref({
client_id: getClientId(),
task_type: "image", task_type: "image",
rate: rates[0].value, rate: rates[0].value,
model: models[0].value, model: models[0].value,
@@ -310,6 +309,8 @@ const isLogin = ref(false);
const prompt = ref(""); const prompt = ref("");
const store = useSharedStore(); const store = useSharedStore();
const clipboard = ref(null); const clipboard = ref(null);
const taskPulling = ref(true);
const downloadPulling = ref(false);
onMounted(() => { onMounted(() => {
clipboard.value = new Clipboard(".copy-prompt"); clipboard.value = new Clipboard(".copy-prompt");
@@ -327,21 +328,23 @@ onMounted(() => {
isLogin.value = true; isLogin.value = true;
fetchRunningJobs(); fetchRunningJobs();
fetchFinishJobs(1); fetchFinishJobs(1);
setInterval(() => {
if (taskPulling.value) {
fetchRunningJobs();
}
}, 5000);
setInterval(() => {
if (downloadPulling.value) {
page.value = 1;
fetchFinishJobs(1);
}
}, 5000);
}) })
.catch(() => { .catch(() => {
// router.push('/login') // router.push('/login')
}); });
store.addMessageHandler("mj", (data) => {
if (data.channel !== "mj" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 1;
fetchFinishJobs(1);
}
fetchRunningJobs();
});
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -362,6 +365,10 @@ getSystemInfo()
// 获取运行中的任务 // 获取运行中的任务
const fetchRunningJobs = (userId) => { const fetchRunningJobs = (userId) => {
if (!isLogin.value) {
return;
}
httpGet(`/api/mj/jobs?finish=0&user_id=${userId}`) httpGet(`/api/mj/jobs?finish=0&user_id=${userId}`)
.then((res) => { .then((res) => {
const jobs = res.data.items; const jobs = res.data.items;
@@ -381,6 +388,14 @@ const fetchRunningJobs = (userId) => {
} }
_jobs.push(jobs[i]); _jobs.push(jobs[i]);
} }
if (runningJobs.value.length !== _jobs.length) {
page.value = 1;
downloadPulling.value = true;
fetchFinishJobs(1);
}
if (_jobs.length === 0) {
taskPulling.value = false;
}
runningJobs.value = _jobs; runningJobs.value = _jobs;
}) })
.catch((e) => { .catch((e) => {
@@ -394,11 +409,16 @@ const error = ref(false);
const page = ref(0); const page = ref(0);
const pageSize = ref(10); const pageSize = ref(10);
const fetchFinishJobs = (page) => { const fetchFinishJobs = (page) => {
if (!isLogin.value) {
return;
}
loading.value = true; loading.value = true;
// 获取已完成的任务 // 获取已完成的任务
httpGet(`/api/mj/jobs?finish=1&page=${page}&page_size=${pageSize.value}`) httpGet(`/api/mj/jobs?finish=1&page=${page}&page_size=${pageSize.value}`)
.then((res) => { .then((res) => {
const jobs = res.data.items; const jobs = res.data.items;
let hasDownload = false;
for (let i = 0; i < jobs.length; i++) { for (let i = 0; i < jobs.length; i++) {
if (jobs[i].type === "upscale" || jobs[i].type === "swapFace") { if (jobs[i].type === "upscale" || jobs[i].type === "swapFace") {
jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/600/q/75"; jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/600/q/75";
@@ -406,13 +426,23 @@ const fetchFinishJobs = (page) => {
jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/480/q/75"; jobs[i]["thumb_url"] = jobs[i]["img_url"] + "?imageView2/1/w/480/h/480/q/75";
} }
if (jobs[i]["img_url"] === "" && jobs[i].progress === 100) {
hasDownload = true;
}
if (jobs[i].type !== "upscale" && jobs[i].progress === 100) { if (jobs[i].type !== "upscale" && jobs[i].progress === 100) {
jobs[i]["can_opt"] = true; jobs[i]["can_opt"] = true;
} }
} }
if (page === 1) {
downloadPulling.value = hasDownload;
}
if (jobs.length < pageSize.value) { if (jobs.length < pageSize.value) {
finished.value = true; finished.value = true;
} }
if (page === 1) { if (page === 1) {
finishedJobs.value = jobs; finishedJobs.value = jobs;
} else { } else {
@@ -480,7 +510,6 @@ const uploadImg = (file) => {
const send = (url, index, item) => { const send = (url, index, item) => {
httpPost(url, { httpPost(url, {
client_id: getClientId(),
index: index, index: index,
channel_id: item.channel_id, channel_id: item.channel_id,
message_id: item.message_id, message_id: item.message_id,
@@ -491,7 +520,9 @@ const send = (url, index, item) => {
.then(() => { .then(() => {
showSuccessToast("任务推送成功,请耐心等待任务执行..."); showSuccessToast("任务推送成功,请耐心等待任务执行...");
power.value -= mjActionPower.value; power.value -= mjActionPower.value;
fetchRunningJobs(); runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
showFailToast("任务推送失败:" + e.message); showFailToast("任务推送失败:" + e.message);
@@ -525,7 +556,10 @@ const generate = () => {
.then(() => { .then(() => {
showToast("绘画任务推送成功,请耐心等待任务执行"); showToast("绘画任务推送成功,请耐心等待任务执行");
power.value -= mjPower.value; power.value -= mjPower.value;
fetchRunningJobs(); taskPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
showFailToast("任务推送失败:" + e.message); showFailToast("任务推送失败:" + e.message);

View File

@@ -175,7 +175,7 @@ import { onMounted, onUnmounted, ref } from "vue";
import { Delete } from "@element-plus/icons-vue"; import { Delete } from "@element-plus/icons-vue";
import { httpGet, httpPost } from "@/utils/http"; import { httpGet, httpPost } from "@/utils/http";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import { checkSession, getClientId, getSystemInfo } from "@/store/cache"; import { checkSession, getSystemInfo } from "@/store/cache";
import { useRouter } from "vue-router"; import { useRouter } from "vue-router";
import { getSessionId } from "@/store/session"; import { getSessionId } from "@/store/session";
import { showConfirmDialog, showDialog, showFailToast, showImagePreview, showNotify, showSuccessToast, showToast } from "vant"; import { showConfirmDialog, showDialog, showFailToast, showImagePreview, showNotify, showSuccessToast, showToast } from "vant";
@@ -211,7 +211,6 @@ const upscaleAlgArr = ref([
const showUpscalePicker = ref(false); const showUpscalePicker = ref(false);
const params = ref({ const params = ref({
client_id: getClientId(),
width: 1024, width: 1024,
height: 1024, height: 1024,
sampler: samplers.value[0].value, sampler: samplers.value[0].value,
@@ -229,6 +228,7 @@ const params = ref({
const runningJobs = ref([]); const runningJobs = ref([]);
const finishedJobs = ref([]); const finishedJobs = ref([]);
const allowPulling = ref(true); // 是否允许轮询
const router = useRouter(); const router = useRouter();
// 检查是否有画同款的参数 // 检查是否有画同款的参数
const _params = router.currentRoute.value.params["copyParams"]; const _params = router.currentRoute.value.params["copyParams"];
@@ -260,17 +260,6 @@ onMounted(() => {
.catch((e) => { .catch((e) => {
showNotify({ type: "danger", message: "获取系统配置失败:" + e.message }); showNotify({ type: "danger", message: "获取系统配置失败:" + e.message });
}); });
store.addMessageHandler("sd", (data) => {
if (data.channel !== "sd" || data.clientId !== getClientId()) {
return;
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 1;
fetchFinishJobs(1);
}
fetchRunningJobs();
});
}); });
onUnmounted(() => { onUnmounted(() => {
@@ -286,6 +275,12 @@ const initData = () => {
isLogin.value = true; isLogin.value = true;
fetchRunningJobs(); fetchRunningJobs();
fetchFinishJobs(1); fetchFinishJobs(1);
setInterval(() => {
if (allowPulling.value) {
fetchRunningJobs();
}
}, 5000);
}) })
.catch(() => { .catch(() => {
loading.value = false; loading.value = false;
@@ -309,6 +304,14 @@ const fetchRunningJobs = () => {
} }
_jobs.push(jobs[i]); _jobs.push(jobs[i]);
} }
if (runningJobs.value.length !== _jobs.length) {
fetchFinishJobs(1);
}
if (runningJobs.value.length === 0) {
allowPulling.value = false;
}
runningJobs.value = _jobs; runningJobs.value = _jobs;
}) })
.catch((e) => { .catch((e) => {
@@ -375,7 +378,10 @@ const generate = () => {
.then(() => { .then(() => {
showSuccessToast("绘画任务推送成功,请耐心等待任务执行..."); showSuccessToast("绘画任务推送成功,请耐心等待任务执行...");
power.value -= sdPower.value; power.value -= sdPower.value;
fetchRunningJobs(); allowPulling.value = true;
runningJobs.value.push({
progress: 0,
});
}) })
.catch((e) => { .catch((e) => {
showFailToast("任务推送失败:" + e.message); showFailToast("任务推送失败:" + e.message);