mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
255 lines
6.6 KiB
Go
255 lines
6.6 KiB
Go
package admin
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"fmt"
|
||
"geekai/core"
|
||
"geekai/core/types"
|
||
"geekai/handler"
|
||
"geekai/service"
|
||
"geekai/service/oss"
|
||
"geekai/store/model"
|
||
"geekai/store/vo"
|
||
"geekai/utils"
|
||
"geekai/utils/resp"
|
||
"github.com/gin-gonic/gin"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type ImageHandler struct {
|
||
handler.BaseHandler
|
||
userService *service.UserService
|
||
uploader *oss.UploaderManager
|
||
}
|
||
|
||
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
|
||
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||
}
|
||
|
||
// imageQuery is the JSON request body shared by the job list endpoints.
type imageQuery struct {
	// Prompt filters jobs whose prompt contains this text; empty disables it.
	Prompt string `json:"prompt"`
	// Username restricts the list to one user's jobs; empty disables it.
	Username string `json:"username"`
	// CreatedAt is applied as an inclusive [start, end] range on created_at,
	// and only when it holds exactly two elements.
	CreatedAt []string `json:"created_at"`
	// Page is the 1-based page number.
	Page int `json:"page"`
	// PageSize is the number of rows per page.
	PageSize int `json:"page_size"`
}
|
||
|
||
// MjList Midjourney 任务列表
|
||
func (h *ImageHandler) MjList(c *gin.Context) {
|
||
var data imageQuery
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
session := h.DB.Session(&gorm.Session{})
|
||
if data.Username != "" {
|
||
var user model.User
|
||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||
if err == nil {
|
||
session = session.Where("user_id", user.Id)
|
||
}
|
||
}
|
||
if data.Prompt != "" {
|
||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||
}
|
||
if len(data.CreatedAt) == 2 {
|
||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||
}
|
||
var total int64
|
||
session.Model(&model.MidJourneyJob{}).Count(&total)
|
||
var list []model.MidJourneyJob
|
||
var items = make([]vo.MidJourneyJob, 0)
|
||
offset := (data.Page - 1) * data.PageSize
|
||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||
if err == nil {
|
||
// 填充数据
|
||
for _, item := range list {
|
||
var job vo.MidJourneyJob
|
||
err = utils.CopyObject(item, &job)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
job.CreatedAt = item.CreatedAt.Unix()
|
||
items = append(items, job)
|
||
}
|
||
}
|
||
|
||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||
}
|
||
|
||
// SdList Stable Diffusion 任务列表
|
||
func (h *ImageHandler) SdList(c *gin.Context) {
|
||
var data imageQuery
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
session := h.DB.Session(&gorm.Session{})
|
||
if data.Username != "" {
|
||
var user model.User
|
||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||
if err == nil {
|
||
session = session.Where("user_id", user.Id)
|
||
}
|
||
}
|
||
if data.Prompt != "" {
|
||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||
}
|
||
if len(data.CreatedAt) == 2 {
|
||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||
}
|
||
var total int64
|
||
session.Model(&model.SdJob{}).Count(&total)
|
||
var list []model.SdJob
|
||
var items = make([]vo.SdJob, 0)
|
||
offset := (data.Page - 1) * data.PageSize
|
||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||
if err == nil {
|
||
// 填充数据
|
||
for _, item := range list {
|
||
var job vo.SdJob
|
||
err = utils.CopyObject(item, &job)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
job.CreatedAt = item.CreatedAt.Unix()
|
||
items = append(items, job)
|
||
}
|
||
}
|
||
|
||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||
}
|
||
|
||
// DallList DALL-E 任务列表
|
||
func (h *ImageHandler) DallList(c *gin.Context) {
|
||
var data imageQuery
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
session := h.DB.Session(&gorm.Session{})
|
||
if data.Username != "" {
|
||
var user model.User
|
||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||
if err == nil {
|
||
session = session.Where("user_id", user.Id)
|
||
}
|
||
}
|
||
if data.Prompt != "" {
|
||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||
}
|
||
if len(data.CreatedAt) == 2 {
|
||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||
}
|
||
var total int64
|
||
session.Model(&model.DallJob{}).Count(&total)
|
||
var list []model.DallJob
|
||
var items = make([]vo.DallJob, 0)
|
||
offset := (data.Page - 1) * data.PageSize
|
||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||
if err == nil {
|
||
// 填充数据
|
||
for _, item := range list {
|
||
var job vo.DallJob
|
||
err = utils.CopyObject(item, &job)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
job.CreatedAt = item.CreatedAt.Unix()
|
||
items = append(items, job)
|
||
}
|
||
}
|
||
|
||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||
}
|
||
|
||
func (h *ImageHandler) Remove(c *gin.Context) {
|
||
id := h.GetInt(c, "id", 0)
|
||
tab := c.Query("tab")
|
||
|
||
tx := h.DB.Begin()
|
||
var md, remark, imgURL string
|
||
var power, userId, progress int
|
||
switch tab {
|
||
case "mj":
|
||
var job model.MidJourneyJob
|
||
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
|
||
resp.ERROR(c, "记录不存在")
|
||
return
|
||
}
|
||
tx.Delete(&job)
|
||
md = "mid-journey"
|
||
power = job.Power
|
||
userId = job.UserId
|
||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||
progress = job.Progress
|
||
imgURL = job.ImgURL
|
||
break
|
||
case "sd":
|
||
var job model.SdJob
|
||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||
resp.ERROR(c, "记录不存在")
|
||
return
|
||
}
|
||
|
||
// 删除任务
|
||
tx.Delete(&job)
|
||
md = "stable-diffusion"
|
||
power = job.Power
|
||
userId = job.UserId
|
||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||
progress = job.Progress
|
||
imgURL = job.ImgURL
|
||
break
|
||
case "dall":
|
||
var job model.DallJob
|
||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||
resp.ERROR(c, "记录不存在")
|
||
return
|
||
}
|
||
|
||
// 删除任务
|
||
tx.Delete(&job)
|
||
md = "dall-e-3"
|
||
power = job.Power
|
||
userId = int(job.UserId)
|
||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||
progress = job.Progress
|
||
imgURL = job.ImgURL
|
||
break
|
||
default:
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
if progress != 100 {
|
||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||
Type: types.PowerRefund,
|
||
Model: md,
|
||
Remark: remark,
|
||
})
|
||
if err != nil {
|
||
tx.Rollback()
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
}
|
||
tx.Commit()
|
||
// remove image
|
||
err := h.uploader.GetUploadHandler().Delete(imgURL)
|
||
if err != nil {
|
||
logger.Error("remove image failed: ", err)
|
||
}
|
||
|
||
resp.SUCCESS(c)
|
||
}
|