suno and luma task management funtion in admin console is ready

This commit is contained in:
RockYang
2024-10-10 17:07:40 +08:00
parent d34b785238
commit a678a11c33
17 changed files with 818 additions and 90 deletions

View File

@@ -8,9 +8,12 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -21,23 +24,25 @@ import (
type ImageHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewImageHandler(app *core.AppServer, db *gorm.DB) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type query struct {
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_time"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MjList Midjourney 任务列表
func (h *ImageHandler) MjList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -55,9 +60,7 @@ func (h *ImageHandler) MjList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
@@ -83,7 +86,7 @@ func (h *ImageHandler) MjList(c *gin.Context) {
// SdList Stable Diffusion 任务列表
func (h *ImageHandler) SdList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -101,9 +104,7 @@ func (h *ImageHandler) SdList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SdJob{}).Count(&total)
@@ -129,7 +130,7 @@ func (h *ImageHandler) SdList(c *gin.Context) {
// DallList DALL-E 任务列表
func (h *ImageHandler) DallList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -147,9 +148,7 @@ func (h *ImageHandler) DallList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.DallJob{}).Count(&total)
@@ -172,3 +171,84 @@ func (h *ImageHandler) DallList(c *gin.Context) {
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *ImageHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, imgURL string
var power, userId, progress int
switch tab {
case "mj":
var job model.MidJourneyJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "mid-journey"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "sd":
var job model.SdJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "stable-diffusion"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "dall":
var job model.DallJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "dall-e-3"
power = job.Power
userId = int(job.UserId)
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(imgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}