urgent bug fix: remove suno and luma task will recharge user power

This commit is contained in:
RockYang 2024-09-18 20:33:29 +08:00
parent e9ac58b1ef
commit be45f41e34
3 changed files with 62 additions and 29 deletions

View File

@ -8,13 +8,13 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"errors"
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"errors"
"fmt"
"gorm.io/gorm" "gorm.io/gorm"
"strings" "strings"
@ -85,7 +85,7 @@ func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
} }
var user model.User var user model.User
res := h.DB.First(&user, userId) res := h.DB.Where("id", userId).First(&user)
// 更新缓存 // 更新缓存
if res.Error == nil { if res.Error == nil {
c.Set(types.LoginUserCache, user) c.Set(types.LoginUserCache, user)

View File

@ -86,6 +86,17 @@ func (h *SunoHandler) Create(c *gin.Context) {
return return
} }
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
// 歌曲拼接 // 歌曲拼接
if data.SongId != "" && data.Type == 3 { if data.SongId != "" && data.Type == 3 {
var song model.SunoJob var song model.SunoJob
@ -143,7 +154,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
}) })
// update user's power // update user's power
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -225,6 +236,13 @@ func (h *SunoHandler) Remove(c *gin.Context) {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 只有失败,或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}
// 删除任务 // 删除任务
tx := h.DB.Begin() tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil { if err := tx.Delete(&job).Error; err != nil {
@ -233,18 +251,16 @@ func (h *SunoHandler) Remove(c *gin.Context) {
return return
} }
// 如果任务未完成,或者任务失败,则恢复用户算力 // 恢复用户算力
if job.Progress != 100 { err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerRefund,
Type: types.PowerRefund, Model: job.ModelName,
Model: job.ModelName, Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg), })
}) if err != nil {
if err != nil { tx.Rollback()
tx.Rollback() resp.ERROR(c, err.Error())
resp.ERROR(c, err.Error()) return
return
}
} }
tx.Commit() tx.Commit()

View File

@ -22,6 +22,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http" "net/http"
"time"
) )
type VideoHandler struct { type VideoHandler struct {
@ -77,6 +78,18 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
if data.Prompt == "" { if data.Prompt == "" {
resp.ERROR(c, "prompt is needed") resp.ERROR(c, "prompt is needed")
return return
@ -113,7 +126,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
}) })
// update user's power // update user's power
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,
Model: "luma", Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id), Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
@ -184,6 +197,12 @@ func (h *VideoHandler) Remove(c *gin.Context) {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
// 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*30)) {
resp.ERROR(c, "只有失败和超时(30分钟)的任务才能删除!")
return
}
// 删除任务 // 删除任务
tx := h.DB.Begin() tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil { if err := tx.Delete(&job).Error; err != nil {
@ -192,18 +211,16 @@ func (h *VideoHandler) Remove(c *gin.Context) {
return return
} }
// 如果任务未完成,或者任务失败,则恢复用户算力 // 恢复算力
if job.Progress != 100 { err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerRefund,
Type: types.PowerRefund, Model: "luma",
Model: "luma", Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg), })
}) if err != nil {
if err != nil { tx.Rollback()
tx.Rollback() resp.ERROR(c, err.Error())
resp.ERROR(c, err.Error()) return
return
}
} }
tx.Commit() tx.Commit()