suno and luma task management funtion in admin console is ready

This commit is contained in:
RockYang
2024-10-10 17:07:40 +08:00
parent d34b785238
commit a678a11c33
17 changed files with 818 additions and 90 deletions

View File

@@ -8,9 +8,12 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -21,23 +24,25 @@ import (
type ImageHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewImageHandler(app *core.AppServer, db *gorm.DB) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type query struct {
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_time"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MjList Midjourney 任务列表
func (h *ImageHandler) MjList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -55,9 +60,7 @@ func (h *ImageHandler) MjList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
@@ -83,7 +86,7 @@ func (h *ImageHandler) MjList(c *gin.Context) {
// SdList Stable Diffusion 任务列表
func (h *ImageHandler) SdList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -101,9 +104,7 @@ func (h *ImageHandler) SdList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SdJob{}).Count(&total)
@@ -129,7 +130,7 @@ func (h *ImageHandler) SdList(c *gin.Context) {
// DallList DALL-E 任务列表
func (h *ImageHandler) DallList(c *gin.Context) {
var data query
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
@@ -147,9 +148,7 @@ func (h *ImageHandler) DallList(c *gin.Context) {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00")
end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00")
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.DallJob{}).Count(&total)
@@ -172,3 +171,84 @@ func (h *ImageHandler) DallList(c *gin.Context) {
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *ImageHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, imgURL string
var power, userId, progress int
switch tab {
case "mj":
var job model.MidJourneyJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "mid-journey"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "sd":
var job model.SdJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "stable-diffusion"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "dall":
var job model.DallJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "dall-e-3"
power = job.Power
userId = int(job.UserId)
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(imgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,200 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MediaHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type mediaQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// SunoList Suno 任务列表
func (h *MediaHandler) SunoList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SunoJob{}).Count(&total)
var list []model.SunoJob
var items = make([]vo.SunoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SunoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// LumaList Luma 视频任务列表
func (h *MediaHandler) LumaList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.VideoJob{}).Count(&total)
var list []model.VideoJob
var items = make([]vo.VideoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.VideoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
if job.VideoURL == "" {
job.VideoURL = job.WaterURL
}
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *MediaHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, fileURL string
var power, userId, progress int
switch tab {
case "suno":
var job model.SunoJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "suno"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("SUNO 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.AudioURL
break
case "luma":
var job model.VideoJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = job.Type
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("LUMA 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.VideoURL
if fileURL == "" {
fileURL = job.WaterURL
}
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(fileURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -180,12 +180,7 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{

View File

@@ -403,12 +403,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
// remove job recode
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{

View File

@@ -250,18 +250,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s Err: %s", job.TaskId, job.ErrMsg),
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d Err: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()

View File

@@ -156,6 +156,9 @@ func (h *VideoHandler) List(c *gin.Context) {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if item.VideoURL == "" {
item.VideoURL = v.WaterURL
}
items = append(items, item)
}