diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index d52ccf4e..94c49d4f 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -8,7 +8,6 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( - "fmt" "geekai/core" "geekai/core/types" "geekai/service" @@ -182,25 +181,14 @@ func (h *DallJobHandler) Remove(c *gin.Context) { } // 删除任务 - tx := h.DB.Begin() - tx.Delete(&job) - // 如果任务未完成,或者任务失败,则恢复用户算力 - if job.Progress != 100 { - err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "dall-e-3", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), - }) - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } + err := h.DB.Delete(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return } - tx.Commit() // remove image - err := h.uploader.GetUploadHandler().Delete(job.ImgURL) + err = h.uploader.GetUploadHandler().Delete(job.ImgURL) if err != nil { logger.Error("remove image failed: ", err) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 532134be..a845d740 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -406,26 +406,15 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { return } - // remove job recode - tx := h.DB.Begin() - tx.Delete(&job) - // 如果任务未完成,或者任务失败,则恢复用户算力 - if job.Progress != 100 { - err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "mid-journey", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), - }) - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } + // remove job + err := h.DB.Delete(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return } - tx.Commit() // remove image - err := h.uploader.GetUploadHandler().Delete(job.ImgURL) + err = h.uploader.GetUploadHandler().Delete(job.ImgURL) if err != nil { logger.Error("remove image failed: ", err) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 7c41cbb1..93ca386e 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -252,25 +252,14 @@ func (h *SdJobHandler) Remove(c *gin.Context) { } // 删除任务 - tx := h.DB.Begin() - tx.Delete(&job) - // 如果任务未完成,或者任务失败,则恢复用户算力 - if job.Progress != 100 { - err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "stable-diffusion", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg), - }) - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } + err := h.DB.Delete(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return } - tx.Commit() // remove image - err := h.uploader.GetUploadHandler().Delete(job.ImgURL) + err = h.uploader.GetUploadHandler().Delete(job.ImgURL) if err != nil { logger.Error("remove image failed: ", err) } diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 06bb5fa1..aa6b7c11 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -222,25 +222,11 @@ func (h *SunoHandler) Remove(c *gin.Context) { } // 删除任务 - tx := h.DB.Begin() - if err := tx.Delete(&job).Error; err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - // 恢复用户算力 - err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - }) + err = h.DB.Delete(&job).Error if err != nil { - tx.Rollback() resp.ERROR(c, err.Error()) return } - tx.Commit() // 删除文件 _ = h.uploader.GetUploadHandler().Delete(job.CoverURL) diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index 10967f62..e42cd9ca 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -183,25 +183,11 @@ func (h *VideoHandler) Remove(c *gin.Context) { } // 删除任务 - tx := h.DB.Begin() - if err := tx.Delete(&job).Error; err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - // 恢复算力 - err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "luma", - Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - }) + err = h.DB.Delete(&job).Error if err != nil { - tx.Rollback() resp.ERROR(c, err.Error()) return } - tx.Commit() // 删除文件 _ = h.uploader.GetUploadHandler().Delete(job.CoverURL) diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 4e9c2504..24206a40 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -237,13 +237,9 @@ func (s *Service) CheckTaskStatus() { go func() { logger.Info("Running DALL-E task status checking ...") for { + // 检查未完成任务进度 var jobs []model.DallJob - res := s.db.Where("progress < ?", 100).Find(&jobs) - if res.Error != nil { - time.Sleep(5 * time.Second) - continue - } - + s.db.Where("progress < ?", 100).Find(&jobs) for _, job := range jobs { // 超时的任务标记为失败 if time.Now().Sub(job.CreatedAt) > time.Minute*10 { @@ -252,6 +248,21 @@ func (s *Service) CheckTaskStatus() { s.db.Updates(&job) } } + + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } time.Sleep(time.Second * 10) } }() diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 249c2f87..f6d08be1 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -30,10 +30,11 @@ type Service struct { db *gorm.DB wsService *service.WebsocketService uploaderManager *oss.UploaderManager + userService *service.UserService clientIds map[uint]string } -func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service { +func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService, userService *service.UserService) *Service { return &Service{ db: db, taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), @@ -42,6 +43,7 @@ func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *os wsService: wsService, uploaderManager: manager, clientIds: map[uint]string{}, + userService: userService, } } @@ -313,6 +315,21 @@ func (s *Service) SyncTaskProgress() { } } + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "mid-journey", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } + time.Sleep(time.Second * 5) } }() diff --git a/api/service/sd/service.go b/api/service/sd/service.go index fd196049..427b6cf8 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -34,9 +34,10 @@ type Service struct { db *gorm.DB uploadManager *oss.UploaderManager wsService *service.WebsocketService + userService *service.UserService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service { return &Service{ httpClient: req.C(), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), @@ -44,6 +45,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD db: db, wsService: wsService, uploadManager: manager, + userService: userService, } } @@ -301,6 +303,21 @@ func (s *Service) CheckTaskStatus() { s.db.Updates(&job) } } + + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "stable-diffusion", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } time.Sleep(time.Second * 5) } }() diff --git a/api/service/suno/service.go b/api/service/suno/service.go index 21e713b8..f724f5ad 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -36,9 +36,10 @@ type Service struct { notifyQueue *store.RedisQueue wsService *service.WebsocketService clientIds map[string]string + userService *service.UserService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, @@ -47,6 +48,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien uploadManager: manager, wsService: wsService, clientIds: map[string]string{}, + userService: userService, } } @@ -384,6 +386,20 @@ func (s *Service) SyncTaskProgress() { } } + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: job.ModelName, + Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } time.Sleep(time.Second * 10) } }() diff --git a/api/service/video/luma.go b/api/service/video/luma.go index ff164998..08e133c4 100644 --- a/api/service/video/luma.go +++ b/api/service/video/luma.go @@ -36,9 +36,10 @@ type Service struct { notifyQueue *store.RedisQueue wsService *service.WebsocketService clientIds map[uint]string + userService *service.UserService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, wsService *service.WebsocketService, userService *service.UserService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, @@ -47,6 +48,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien wsService: wsService, uploadManager: manager, clientIds: map[uint]string{}, + userService: userService, } } @@ -286,6 +288,20 @@ func (s *Service) SyncTaskProgress() { } + // 找出失败的任务,并恢复其扣减算力 + s.db.Where("progress", service.FailTaskProgress).Where("power > ?", 0).Find(&jobs) + for _, job := range jobs { + err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "luma", + Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) + if err != nil { + continue + } + // 更新任务状态 + s.db.Model(&job).UpdateColumn("power", 0) + } time.Sleep(time.Second * 10) } }()