From be45f41e3449d49611ff156ad8a24638f900448c Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 18 Sep 2024 20:33:29 +0800 Subject: [PATCH] urgent bug fix: remove suno and luma task will recharge user power --- api/handler/base_handler.go | 6 ++--- api/handler/suno_handler.go | 42 ++++++++++++++++++++++++----------- api/handler/video_handler.go | 43 +++++++++++++++++++++++++----------- 3 files changed, 62 insertions(+), 29 deletions(-) diff --git a/api/handler/base_handler.go b/api/handler/base_handler.go index 406b9b53..cb2b15ca 100644 --- a/api/handler/base_handler.go +++ b/api/handler/base_handler.go @@ -8,13 +8,13 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "errors" + "fmt" "geekai/core" "geekai/core/types" logger2 "geekai/logger" "geekai/store/model" "geekai/utils" - "errors" - "fmt" "gorm.io/gorm" "strings" @@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) { } var user model.User - res := h.DB.First(&user, userId) + res := h.DB.Where("id", userId).First(&user) // 更新缓存 if res.Error == nil { c.Set(types.LoginUserCache, user) diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 06dfce9c..7aaeab72 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -86,6 +86,17 @@ func (h *SunoHandler) Create(c *gin.Context) { return } + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + if user.Power < h.App.SysConfig.SunoPower { + resp.ERROR(c, "您的算力不足,请充值后再试!") + return + } + // 歌曲拼接 if data.SongId != "" && data.Type == 3 { var song model.SunoJob @@ -143,7 +154,7 @@ func (h *SunoHandler) Create(c *gin.Context) { }) // update user's power - err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), CreatedAt: time.Now(), @@ -225,6 +236,13 @@ func (h *SunoHandler) Remove(c *gin.Context) { resp.ERROR(c, err.Error()) return } + + // 只有失败,或者超时的任务才能删除 + if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) { + resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!") + return + } + // 删除任务 tx := h.DB.Begin() if err := tx.Delete(&job).Error; err != nil { @@ -233,18 +251,16 @@ func (h *SunoHandler) Remove(c *gin.Context) { return } - // 如果任务未完成,或者任务失败,则恢复用户算力 - if job.Progress != 100 { - err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - }) - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } + // 恢复用户算力 + err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: job.ModelName, + Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return } tx.Commit() diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index 4f1d81d1..31c34e57 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -22,6 +22,7 @@ import ( "github.com/gorilla/websocket" "gorm.io/gorm" "net/http" + "time" ) type VideoHandler struct { @@ -77,6 +78,18 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + if user.Power < h.App.SysConfig.LumaPower { + resp.ERROR(c, "您的算力不足,请充值后再试!") + return + } + if data.Prompt == "" { resp.ERROR(c, "prompt is needed") return @@ -113,7 +126,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { }) // update user's power - err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerConsume, Model: "luma", Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id), @@ -184,6 +197,12 @@ func (h *VideoHandler) Remove(c *gin.Context) { resp.ERROR(c, err.Error()) return } + // 只有失败或者超时的任务才能删除 + if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) { + resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!") + return + } + // 删除任务 tx := h.DB.Begin() if err := tx.Delete(&job).Error; err != nil { @@ -192,18 +211,16 @@ func (h *VideoHandler) Remove(c *gin.Context) { return } - // 如果任务未完成,或者任务失败,则恢复用户算力 - if job.Progress != 100 { - err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ - Type: types.PowerRefund, - Model: "luma", - Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - }) - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } + // 恢复算力 + err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "luma", + Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return } tx.Commit()