mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-06 11:14:24 +08:00
即梦 AI 管理后台功能完成
This commit is contained in:
@@ -194,7 +194,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "sd":
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -210,7 +209,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "dall":
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -226,7 +224,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
@@ -19,13 +22,17 @@ import (
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service) *AdminJimengHandler {
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,8 +41,7 @@ func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/jimeng/")
|
||||
rg.GET("/jobs", h.Jobs)
|
||||
rg.GET("/jobs/:id", h.JobDetail)
|
||||
rg.DELETE("/jobs/:id", h.Remove)
|
||||
rg.POST("/jobs/batch-remove", h.BatchRemove)
|
||||
rg.POST("/jobs/remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.GET("/config", h.GetConfig)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
@@ -107,24 +113,6 @@ func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *AdminJimengHandler) Remove(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除任务
|
||||
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
var req struct {
|
||||
@@ -136,23 +124,57 @@ func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result := h.DB.Where("id IN ?", req.JobIds).Delete(&model.JimengJob{})
|
||||
if result.Error != nil {
|
||||
resp.ERROR(c, "批量删除失败")
|
||||
return
|
||||
var deletedCount int64 = 0
|
||||
for _, jobId := range req.JobIds {
|
||||
var job model.JimengJob
|
||||
err := h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
continue // 跳过不存在的
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
|
||||
remark := fmt.Sprintf("任务未成功,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = tx.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
tx.Commit()
|
||||
deletedCount++
|
||||
if job.ImgURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
}
|
||||
if job.VideoURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
if err != nil {
|
||||
logger.Error("remove video failed: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": "批量删除成功",
|
||||
"deleted_count": result.RowsAffected,
|
||||
"deleted_count": deletedCount,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
@@ -177,14 +199,14 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case "completed":
|
||||
result["completedTasks"] = stat.Count
|
||||
case "processing":
|
||||
result["processingTasks"] = stat.Count
|
||||
case "failed":
|
||||
result["failedTasks"] = stat.Count
|
||||
case "pending":
|
||||
case model.JMTaskStatusInQueue:
|
||||
result["pendingTasks"] = stat.Count
|
||||
case model.JMTaskStatusSuccess:
|
||||
result["completedTasks"] = stat.Count
|
||||
case model.JMTaskStatusGenerating:
|
||||
result["processingTasks"] = stat.Count
|
||||
case model.JMTaskStatusFailed:
|
||||
result["failedTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,33 +215,7 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
|
||||
var config model.Config
|
||||
err := h.DB.Debug().Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
defaultConfig := types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: types.JimengPower{
|
||||
TextToImage: 10,
|
||||
ImageToImage: 15,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 25,
|
||||
TextToVideo: 30,
|
||||
ImageToVideo: 35,
|
||||
},
|
||||
}
|
||||
resp.SUCCESS(c, defaultConfig)
|
||||
return
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "解析配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jimengConfig := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, jimengConfig)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"geekai/core/types"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -33,7 +35,7 @@ func (h *JimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/jimeng")
|
||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
||||
rg.GET("jobs", h.Jobs)
|
||||
rg.POST("jobs", h.Jobs)
|
||||
rg.GET("remove", h.Remove)
|
||||
rg.GET("retry", h.Retry)
|
||||
}
|
||||
@@ -253,28 +255,66 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
var req struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Filter string `json:"filter"`
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
var jobs []model.JimengJob
|
||||
var total int64
|
||||
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||
if req.Filter == "image" {
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToImage,
|
||||
model.JMTaskTypeImageToImage,
|
||||
model.JMTaskTypeImageEdit,
|
||||
model.JMTaskTypeImageEffects,
|
||||
})
|
||||
} else if req.Filter == "video" {
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToVideo,
|
||||
model.JMTaskTypeImageToVideo,
|
||||
})
|
||||
}
|
||||
|
||||
jobs, total, err := h.jimengService.GetUserJobs(user.Id, page, pageSize)
|
||||
if err != nil {
|
||||
logger.Errorf("get user jimeng jobs failed: %v", err)
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
if len(req.Ids) > 0 {
|
||||
query = query.Where("id IN (?)", req.Ids)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": jobs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 填充 VO
|
||||
var jobVos []vo.JimengJob
|
||||
for _, job := range jobs {
|
||||
var jobVo vo.JimengJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
jobVo.CreatedAt = job.CreatedAt.Unix()
|
||||
jobVos = append(jobVos, jobVo)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, req.Page, req.PageSize, jobVos))
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
@@ -355,7 +395,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 重新推送到队列
|
||||
task := map[string]interface{}{
|
||||
task := map[string]any{
|
||||
"job_id": jobId,
|
||||
"type": job.Type,
|
||||
}
|
||||
@@ -393,27 +433,7 @@ func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.Powe
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
config, err := h.jimengService.GetConfig()
|
||||
if err != nil {
|
||||
logger.Errorf("获取即梦AI配置失败: %v", err)
|
||||
// 返回默认值
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return 10
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return 15
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return 20
|
||||
case model.JMTaskTypeImageEffects:
|
||||
return 25
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
return 30
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
return 35
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
config := h.jimengService.GetConfig()
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
@@ -435,11 +455,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config, err := h.jimengService.GetConfig()
|
||||
if err != nil || config == nil {
|
||||
resp.ERROR(c, "获取算力配置失败")
|
||||
return
|
||||
}
|
||||
config := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
|
||||
Reference in New Issue
Block a user