From a678a11c3372d253f0817b1c3e9e32aac57216ad Mon Sep 17 00:00:00 2001 From: RockYang Date: Thu, 10 Oct 2024 17:07:40 +0800 Subject: [PATCH] suno and luma task management funtion in admin console is ready --- CHANGELOG.md | 1 + api/handler/admin/image_handler.go | 112 +++++- api/handler/admin/media_handler.go | 200 ++++++++++ api/handler/dalle_handler.go | 7 +- api/handler/mj_handler.go | 7 +- api/handler/sd_handler.go | 9 +- api/handler/video_handler.go | 3 + api/main.go | 8 + web/.env.development | 2 +- web/.env.production | 2 +- web/src/assets/css/chat-plus.styl | 2 +- web/src/components/admin/AdminSidebar.vue | 5 + web/src/router.js | 6 + web/src/views/ChatPlus.vue | 2 +- web/src/views/admin/ChatList.vue | 41 +- web/src/views/admin/ImageList.vue | 51 +-- web/src/views/admin/Medias.vue | 450 ++++++++++++++++++++++ 17 files changed, 818 insertions(+), 90 deletions(-) create mode 100644 api/handler/admin/media_handler.go create mode 100644 web/src/views/admin/Medias.vue diff --git a/CHANGELOG.md b/CHANGELOG.md index d6aaf544..923e36b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## v4.1.6 * 功能优化:优化MysQL容器配置文档,解决MysQL容器资源占用过高问题 * 功能新增:管理后台增加AI绘图任务管理,可在管理后台浏览和删除用户的绘图任务 +* 功能新增:管理后台增加Suno和Luma任务管理功能 ## v4.1.5 * 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接 diff --git a/api/handler/admin/image_handler.go b/api/handler/admin/image_handler.go index 6241d9aa..3685042a 100644 --- a/api/handler/admin/image_handler.go +++ b/api/handler/admin/image_handler.go @@ -8,9 +8,12 @@ package admin // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "fmt" "geekai/core" "geekai/core/types" "geekai/handler" + "geekai/service" + "geekai/service/oss" "geekai/store/model" "geekai/store/vo" "geekai/utils" @@ -21,23 +24,25 @@ import ( type ImageHandler struct { handler.BaseHandler + userService *service.UserService + uploader *oss.UploaderManager } -func NewImageHandler(app *core.AppServer, db *gorm.DB) *ImageHandler { - return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler { + return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager} } -type query struct { +type imageQuery struct { Prompt string `json:"prompt"` Username string `json:"username"` - CreatedAt []string `json:"created_time"` + CreatedAt []string `json:"created_at"` Page int `json:"page"` PageSize int `json:"page_size"` } // MjList Midjourney 任务列表 func (h *ImageHandler) MjList(c *gin.Context) { - var data query + var data imageQuery if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) return @@ -55,9 +60,7 @@ func (h *ImageHandler) MjList(c *gin.Context) { session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%") } if len(data.CreatedAt) == 2 { - start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00") - end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00") - session = session.Where("created_at >= ? AND created_at <= ?", start, end) + session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1]) } var total int64 session.Model(&model.MidJourneyJob{}).Count(&total) @@ -83,7 +86,7 @@ func (h *ImageHandler) MjList(c *gin.Context) { // SdList Stable Diffusion 任务列表 func (h *ImageHandler) SdList(c *gin.Context) { - var data query + var data imageQuery if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) return @@ -101,9 +104,7 @@ func (h *ImageHandler) SdList(c *gin.Context) { session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%") } if len(data.CreatedAt) == 2 { - start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00") - end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00") - session = session.Where("created_at >= ? AND created_at <= ?", start, end) + session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1]) } var total int64 session.Model(&model.SdJob{}).Count(&total) @@ -129,7 +130,7 @@ func (h *ImageHandler) SdList(c *gin.Context) { // DallList DALL-E 任务列表 func (h *ImageHandler) DallList(c *gin.Context) { - var data query + var data imageQuery if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) return @@ -147,9 +148,7 @@ func (h *ImageHandler) DallList(c *gin.Context) { session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%") } if len(data.CreatedAt) == 2 { - start := utils.Str2stamp(data.CreatedAt[0] + " 00:00:00") - end := utils.Str2stamp(data.CreatedAt[1] + " 00:00:00") - session = session.Where("created_at >= ? AND created_at <= ?", start, end) + session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1]) } var total int64 session.Model(&model.DallJob{}).Count(&total) @@ -172,3 +171,84 @@ func (h *ImageHandler) DallList(c *gin.Context) { resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items)) } + +func (h *ImageHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + tab := c.Query("tab") + + tx := h.DB.Begin() + var md, remark, imgURL string + var power, userId, progress int + switch tab { + case "mj": + var job model.MidJourneyJob + if err := h.DB.Where("id", id).First(&job).Error; err != nil { + resp.ERROR(c, "记录不存在") + return + } + tx.Delete(&job) + md = "mid-journey" + power = job.Power + userId = job.UserId + remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg) + progress = job.Progress + imgURL = job.ImgURL + break + case "sd": + var job model.SdJob + if res := h.DB.Where("id", id).First(&job); res.Error != nil { + resp.ERROR(c, "记录不存在") + return + } + + // 删除任务 + tx.Delete(&job) + md = "stable-diffusion" + power = job.Power + userId = job.UserId + remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg) + progress = job.Progress + imgURL = job.ImgURL + break + case "dall": + var job model.DallJob + if res := h.DB.Where("id", id).First(&job); res.Error != nil { + resp.ERROR(c, "记录不存在") + return + } + + // 删除任务 + tx.Delete(&job) + md = "dall-e-3" + power = job.Power + userId = int(job.UserId) + remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg) + progress = job.Progress + imgURL = job.ImgURL + break + default: + resp.ERROR(c, types.InvalidArgs) + return + } + + if progress != 100 { + err := h.userService.IncreasePower(userId, power, model.PowerLog{ + Type: types.PowerRefund, + Model: md, + Remark: remark, + }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + } + tx.Commit() + // remove image + err := h.uploader.GetUploadHandler().Delete(imgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + resp.SUCCESS(c) +} diff --git a/api/handler/admin/media_handler.go b/api/handler/admin/media_handler.go new file mode 100644 index 00000000..4ce42c06 --- /dev/null +++ b/api/handler/admin/media_handler.go @@ -0,0 +1,200 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/service" + "geekai/service/oss" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type MediaHandler struct { + handler.BaseHandler + userService *service.UserService + uploader *oss.UploaderManager +} + +func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler { + return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager} +} + +type mediaQuery struct { + Prompt string `json:"prompt"` + Username string `json:"username"` + CreatedAt []string `json:"created_at"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// SunoList Suno 任务列表 +func (h *MediaHandler) SunoList(c *gin.Context) { + var data mediaQuery + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.Username != "" { + var user model.User + err := h.DB.Where("username", data.Username).First(&user).Error + if err == nil { + session = session.Where("user_id", user.Id) + } + } + if data.Prompt != "" { + session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%") + } + if len(data.CreatedAt) == 2 { + session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1]) + } + var total int64 + session.Model(&model.SunoJob{}).Count(&total) + var list []model.SunoJob + var items = make([]vo.SunoJob, 0) + offset := (data.Page - 1) * data.PageSize + err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error + if err == nil { + // 填充数据 + for _, item := range list { + var job vo.SunoJob + err = utils.CopyObject(item, &job) + if err != nil { + continue + } + job.CreatedAt = item.CreatedAt.Unix() + items = append(items, job) + } + } + + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items)) +} + +// LumaList Luma 视频任务列表 +func (h *MediaHandler) LumaList(c *gin.Context) { + var data mediaQuery + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.Username != "" { + var user model.User + err := h.DB.Where("username", data.Username).First(&user).Error + if err == nil { + session = session.Where("user_id", user.Id) + } + } + if data.Prompt != "" { + session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%") + } + if len(data.CreatedAt) == 2 { + session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1]) + } + var total int64 + session.Model(&model.VideoJob{}).Count(&total) + var list []model.VideoJob + var items = make([]vo.VideoJob, 0) + offset := (data.Page - 1) * data.PageSize + err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error + if err == nil { + // 填充数据 + for _, item := range list { + var job vo.VideoJob + err = utils.CopyObject(item, &job) + if err != nil { + continue + } + job.CreatedAt = item.CreatedAt.Unix() + if job.VideoURL == "" { + job.VideoURL = job.WaterURL + } + items = append(items, job) + } + } + + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items)) +} + +func (h *MediaHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + tab := c.Query("tab") + + tx := h.DB.Begin() + var md, remark, fileURL string + var power, userId, progress int + switch tab { + case "suno": + var job model.SunoJob + if err := h.DB.Where("id", id).First(&job).Error; err != nil { + resp.ERROR(c, "记录不存在") + return + } + tx.Delete(&job) + md = "suno" + power = job.Power + userId = job.UserId + remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg) + progress = job.Progress + fileURL = job.AudioURL + break + case "luma": + var job model.VideoJob + if res := h.DB.Where("id", id).First(&job); res.Error != nil { + resp.ERROR(c, "记录不存在") + return + } + + // 删除任务 + tx.Delete(&job) + md = job.Type + power = job.Power + userId = job.UserId + remark = fmt.Sprintf("LUMA 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg) + progress = job.Progress + fileURL = job.VideoURL + if fileURL == "" { + fileURL = job.WaterURL + } + break + default: + resp.ERROR(c, types.InvalidArgs) + return + } + + if progress != 100 { + err := h.userService.IncreasePower(userId, power, model.PowerLog{ + Type: types.PowerRefund, + Model: md, + Remark: remark, + }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + } + tx.Commit() + // remove image + err := h.uploader.GetUploadHandler().Delete(fileURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + resp.SUCCESS(c) +} diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index eb46710f..404c9704 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -180,12 +180,7 @@ func (h *DallJobHandler) Remove(c *gin.Context) { // 删除任务 tx := h.DB.Begin() - if err := tx.Delete(&job).Error; err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - + tx.Delete(&job) // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{ diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 8feaeb53..858a0d89 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -403,12 +403,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { // remove job recode tx := h.DB.Begin() - if err := tx.Delete(&job).Error; err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - + tx.Delete(&job) // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 9f5345dd..437dceac 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -250,18 +250,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) { // 删除任务 tx := h.DB.Begin() - if err := tx.Delete(&job).Error; err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - + tx.Delete(&job) // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ Type: types.PowerRefund, Model: "stable-diffusion", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg), + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg), }) if err != nil { tx.Rollback() diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index 02bf21cb..aaa0bd86 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -156,6 +156,9 @@ func (h *VideoHandler) List(c *gin.Context) { continue } item.CreatedAt = v.CreatedAt.Unix() + if item.VideoURL == "" { + item.VideoURL = v.WaterURL + } items = append(items, item) } diff --git a/api/main.go b/api/main.go index 0810b342..bb2a57e8 100644 --- a/api/main.go +++ b/api/main.go @@ -545,6 +545,14 @@ func main() { group.POST("/list/mj", h.MjList) group.POST("/list/sd", h.SdList) group.POST("/list/dall", h.DallList) + group.GET("/remove", h.Remove) + }), + fx.Provide(admin.NewMediaHandler), + fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) { + group := s.Engine.Group("/api/admin/media") + group.POST("/list/suno", h.SunoList) + group.POST("/list/luma", h.LumaList) + group.GET("/remove", h.Remove) }), ) // 启动应用程序 diff --git a/web/.env.development b/web/.env.development index 99693506..d563a103 100644 --- a/web/.env.development +++ b/web/.env.development @@ -6,6 +6,6 @@ VUE_APP_ADMIN_USER=admin VUE_APP_ADMIN_PASS=admin123 VUE_APP_KEY_PREFIX=GeekAI_DEV_ VUE_APP_TITLE="Geek-AI 创作系统" -VUE_APP_VERSION=v4.1.5 +VUE_APP_VERSION=v4.1.6 VUE_APP_DOCS_URL=https://docs.geekai.me VUE_APP_GIT_URL=https://github.com/yangjian102621/geekai diff --git a/web/.env.production b/web/.env.production index 362a2b93..d265a04a 100644 --- a/web/.env.production +++ b/web/.env.production @@ -1,6 +1,6 @@ VUE_APP_API_HOST= VUE_APP_WS_HOST= VUE_APP_KEY_PREFIX=GeekAI_ -VUE_APP_VERSION=v4.1.5 +VUE_APP_VERSION=v4.1.6 VUE_APP_DOCS_URL=https://docs.geekai.me VUE_APP_GIT_URL=https://github.com/yangjian102621/geekai diff --git a/web/src/assets/css/chat-plus.styl b/web/src/assets/css/chat-plus.styl index a2d73018..06b44eb5 100644 --- a/web/src/assets/css/chat-plus.styl +++ b/web/src/assets/css/chat-plus.styl @@ -14,7 +14,7 @@ $borderColor = #4676d0; padding 10px width var(--el-aside-width, 320px) - .chat-list { + .media-page { display: flex flex-flow: column //background-color: $sideBgColor diff --git a/web/src/components/admin/AdminSidebar.vue b/web/src/components/admin/AdminSidebar.vue index 61354001..357b5b12 100644 --- a/web/src/components/admin/AdminSidebar.vue +++ b/web/src/components/admin/AdminSidebar.vue @@ -142,6 +142,11 @@ const items = [ index: '/admin/images', title: '绘图管理', }, + { + icon: 'mp3', + index: '/admin/medias', + title: '音视频管理', + }, { icon: 'role', index: '/admin/manger', diff --git a/web/src/router.js b/web/src/router.js index 2ca79100..3d5fa603 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -239,6 +239,12 @@ const routes = [ meta: {title: '绘图管理'}, component: () => import('@/views/admin/ImageList.vue'), }, + { + path: '/admin/medias', + name: 'admin-medias', + meta: {title: '音视频管理'}, + component: () => import('@/views/admin/Medias.vue'), + }, { path: '/admin/powerLog', name: 'admin-power-log', diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 19cd964c..13ad0661 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -2,7 +2,7 @@
-
+
diff --git a/web/src/views/admin/ChatList.vue b/web/src/views/admin/ChatList.vue index 971cbd23..fc94afc8 100644 --- a/web/src/views/admin/ChatList.vue +++ b/web/src/views/admin/ChatList.vue @@ -1,5 +1,5 @@