mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-06 19:24:27 +08:00
增加即梦AI功能页面
This commit is contained in:
177
api/handler/admin/jimeng_handler.go
Normal file
177
api/handler/admin/jimeng_handler.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/handler"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminJimengHandler 管理后台即梦AI处理器
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app},
|
||||
jimengService: jimengService,
|
||||
}
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *AdminJimengHandler) Jobs(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
taskType := h.GetTrim(c, "type")
|
||||
status := h.GetTrim(c, "status")
|
||||
|
||||
var tasks []model.JimengJob
|
||||
var total int64
|
||||
|
||||
session := h.DB.Model(&model.JimengJob{})
|
||||
|
||||
// 构建查询条件
|
||||
if userId > 0 {
|
||||
session = session.Where("user_id = ?", userId)
|
||||
}
|
||||
if taskType != "" {
|
||||
session = session.Where("type = ?", taskType)
|
||||
}
|
||||
if status != "" {
|
||||
session = session.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := session.Count(&total).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": tasks,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// JobDetail 获取任务详情
|
||||
func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var job model.JimengJob
|
||||
err = h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *AdminJimengHandler) Remove(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除任务
|
||||
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
var req struct {
|
||||
JobIds []uint `json:"job_ids" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
result := h.DB.Where("id IN ?", req.JobIds).Delete(&model.JimengJob{})
|
||||
if result.Error != nil {
|
||||
resp.ERROR(c, "批量删除失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": "批量删除成功",
|
||||
"deleted_count": result.RowsAffected,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := h.DB.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 整理统计数据
|
||||
result := gin.H{
|
||||
"totalTasks": int64(0),
|
||||
"completedTasks": int64(0),
|
||||
"processingTasks": int64(0),
|
||||
"failedTasks": int64(0),
|
||||
"pendingTasks": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case "completed":
|
||||
result["completedTasks"] = stat.Count
|
||||
case "processing":
|
||||
result["processingTasks"] = stat.Count
|
||||
case "failed":
|
||||
result["failedTasks"] = stat.Count
|
||||
case "pending":
|
||||
result["pendingTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, result)
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
|
||||
@@ -213,7 +213,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
Prompt: prompt,
|
||||
ModelId: 0,
|
||||
ModelName: "dall-e-3",
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
@@ -265,27 +265,27 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 从参数中获取搜索关键词
|
||||
keyword, ok := params["keyword"].(string)
|
||||
if !ok || keyword == "" {
|
||||
resp.ERROR(c, "搜索关键词不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 从参数中获取最大页数,默认为1页
|
||||
maxPages := 1
|
||||
if pages, ok := params["max_pages"].(float64); ok {
|
||||
maxPages = int(pages)
|
||||
}
|
||||
|
||||
|
||||
// 获取用户ID
|
||||
userID, ok := params["user_id"].(float64)
|
||||
if !ok {
|
||||
resp.ERROR(c, "用户ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 查询用户信息
|
||||
var user model.User
|
||||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
||||
@@ -293,21 +293,21 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 检查用户算力是否足够
|
||||
searchPower := 1 // 每次搜索消耗1点算力
|
||||
if user.Power < searchPower {
|
||||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 执行网络搜索
|
||||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 扣减用户算力
|
||||
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -318,7 +318,7 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 返回搜索结果
|
||||
resp.SUCCESS(c, searchResults)
|
||||
}
|
||||
|
||||
639
api/handler/jimeng_handler.go
Normal file
639
api/handler/jimeng_handler.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JimengHandler 即梦AI处理器
|
||||
type JimengHandler struct {
|
||||
BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app},
|
||||
jimengService: jimengService,
|
||||
}
|
||||
}
|
||||
|
||||
// TextToImage 文生图
|
||||
func (h *JimengHandler) TextToImage(c *gin.Context) {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 20 { // 文生图消耗20算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeTextToImage,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyTextToImage,
|
||||
Power: 20,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng text to image task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 20, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦文生图",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageToImagePortrait 图生图人像写真
|
||||
func (h *JimengHandler) ImageToImagePortrait(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageInput string `json:"image_input" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 30 { // 图生图消耗30算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
req.Prompt = "演唱会现场的合照,闪光灯拍摄"
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeImageToImagePortrait,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyImageToImagePortrait,
|
||||
Power: 30,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image to image portrait task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 30, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图生图",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageEdit 图像编辑
|
||||
func (h *JimengHandler) ImageEdit(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
Scale float64 `json:"scale"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
|
||||
resp.ERROR(c, "请提供图片URL或Base64数据")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 25 { // 图像编辑消耗25算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeImageEdit,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyImageEdit,
|
||||
Power: 25,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image edit task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 25, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图像编辑",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageEffects 图像特效
|
||||
func (h *JimengHandler) ImageEffects(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageInput1 string `json:"image_input1" binding:"required"`
|
||||
TemplateId string `json:"template_id" binding:"required"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 15 { // 图像特效消耗15算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"image_input1": req.ImageInput1,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeImageEffects,
|
||||
Prompt: "",
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyImageEffects,
|
||||
Power: 15,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image effects task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 15, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图像特效",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// TextToVideo 文生视频
|
||||
func (h *JimengHandler) TextToVideo(c *gin.Context) {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 100 { // 文生视频消耗100算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeTextToVideo,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyTextToVideo,
|
||||
Power: 100,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng text to video task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 100, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦文生视频",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageToVideo 图生视频
|
||||
func (h *JimengHandler) ImageToVideo(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Prompt string `json:"prompt"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
|
||||
resp.ERROR(c, "请提供图片URL或Base64数据")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < 120 { // 图生视频消耗120算力
|
||||
resp.ERROR(c, "算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JimengJobTypeImageToVideo,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: model.ReqKeyImageToVideo,
|
||||
Power: 120,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image to video task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, 120, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图生视频",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
|
||||
jobs, total, err := h.jimengService.GetUserJobs(user.Id, page, pageSize)
|
||||
if err != nil {
|
||||
logger.Errorf("get user jimeng jobs failed: %v", err)
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": jobs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// PendingCount 获取未完成任务数量
|
||||
func (h *JimengHandler) PendingCount(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.jimengService.GetPendingTaskCount(user.Id)
|
||||
if err != nil {
|
||||
logger.Errorf("get pending task count failed: %v", err)
|
||||
resp.ERROR(c, "获取待处理任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.jimengService.DeleteJob(uint(jobId), user.Id); err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// Retry 重试任务
|
||||
func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
jobIdStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(jobIdStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否存在且属于当前用户
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if job.UserId != user.Id {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败的任务才能重试
|
||||
if job.Status != model.JimengJobStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能重试")
|
||||
return
|
||||
}
|
||||
|
||||
// 重置任务状态
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JimengJobStatusPending, ""); err != nil {
|
||||
logger.Errorf("reset job status failed: %v", err)
|
||||
resp.ERROR(c, "重置任务状态失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 重新推送到队列
|
||||
task := map[string]interface{}{
|
||||
"job_id": jobId,
|
||||
"type": job.Type,
|
||||
}
|
||||
if err := h.jimengService.PushTaskToQueue(task); err != nil {
|
||||
logger.Errorf("push retry task to queue failed: %v", err)
|
||||
resp.ERROR(c, "推送重试任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
||||
}
|
||||
|
||||
// subUserPower 扣除用户算力
|
||||
func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.PowerLog) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
|
||||
// 更新用户算力
|
||||
if err := session.Model(&model.User{}).Where("id = ?", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error; err != nil {
|
||||
logger.Errorf("update user power failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录算力消费日志
|
||||
powerLog.UserId = userId
|
||||
powerLog.Amount = power
|
||||
powerLog.Mark = types.PowerSub
|
||||
powerLog.CreatedAt = time.Now()
|
||||
if err := session.Create(&powerLog).Error; err != nil {
|
||||
logger.Errorf("create power log failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
session.Commit()
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: data.TaskType,
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -79,7 +79,7 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -108,7 +108,7 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -158,9 +158,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *PromptHandler) getPromptModel() string {
|
||||
if h.App.SysConfig.TranslateModelId > 0 {
|
||||
if h.App.SysConfig.AssistantModelId > 0 {
|
||||
var chatModel model.ChatModel
|
||||
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
|
||||
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
|
||||
return chatModel.Value
|
||||
}
|
||||
return "gpt-4o"
|
||||
|
||||
@@ -131,7 +131,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdSteps: data.HdSteps,
|
||||
},
|
||||
UserId: userId,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
@@ -181,7 +181,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Channel: data.Channel,
|
||||
}
|
||||
// 插入数据库
|
||||
|
||||
Reference in New Issue
Block a user