mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-06 19:24:27 +08:00
即梦AI绘图功能前端页面完成
This commit is contained in:
@@ -38,7 +38,7 @@ func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg.POST("/jobs/batch-remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.GET("/config", h.GetConfig)
|
||||
rg.POST("/config", h.UpdateConfig)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
@@ -241,12 +241,6 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
testErr := h.jimengService.TestConnection(req.AccessKey, req.SecretKey)
|
||||
if testErr != nil {
|
||||
resp.ERROR(c, "连接测试失败: "+testErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证算力配置
|
||||
if req.Power.TextToImage <= 0 {
|
||||
resp.ERROR(c, "文生图算力必须大于0")
|
||||
@@ -274,10 +268,11 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
tx := h.DB.Begin()
|
||||
value := utils.JsonEncode(&req)
|
||||
config := model.Config{Name: "jimeng", Value: value}
|
||||
|
||||
err := h.DB.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
||||
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "保存配置失败: "+err.Error())
|
||||
return
|
||||
@@ -285,7 +280,7 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Value = value
|
||||
err = h.DB.Updates(&config).Error
|
||||
err = tx.Updates(&config).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||
return
|
||||
@@ -295,9 +290,11 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
// 更新服务中的客户端配置
|
||||
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
|
||||
if updateErr != nil {
|
||||
// 配置已保存,但客户端更新失败,记录日志但不返回错误
|
||||
logger.Errorf("更新即梦AI客户端配置失败: %v", updateErr)
|
||||
resp.ERROR(c, updateErr.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||
}
|
||||
|
||||
@@ -21,504 +21,230 @@ type JimengHandler struct {
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *JimengHandler {
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app},
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由,新增统一任务接口
|
||||
func (h *JimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/jimeng")
|
||||
rg.POST("text-to-image", h.TextToImage)
|
||||
rg.POST("image-to-image-portrait", h.ImageToImagePortrait)
|
||||
rg.POST("image-edit", h.ImageEdit)
|
||||
rg.POST("image-effects", h.ImageEffects)
|
||||
rg.POST("text-to-video", h.TextToVideo)
|
||||
rg.POST("image-to-video", h.ImageToVideo)
|
||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
||||
rg.GET("jobs", h.Jobs)
|
||||
rg.GET("pending-count", h.PendingCount)
|
||||
rg.GET("remove", h.Remove)
|
||||
rg.GET("retry", h.Retry)
|
||||
}
|
||||
|
||||
// TextToImage 文生图
|
||||
func (h *JimengHandler) TextToImage(c *gin.Context) {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
}
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
// 支持所有生图和生成视频类型
|
||||
type JimengTaskRequest struct {
|
||||
TaskType string `json:"task_type" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageInput string `json:"image_input"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
TemplateId string `json:"template_id"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
// CreateTask 统一任务创建接口
|
||||
func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
var req JimengTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeTextToImage)
|
||||
var powerCost int
|
||||
var taskType model.JMTaskType
|
||||
var params map[string]interface{}
|
||||
var reqKey string
|
||||
var modelName string
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeTextToImage,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyTextToImage,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng text to image task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦文生图",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageToImagePortrait 图生图人像写真
|
||||
func (h *JimengHandler) ImageToImagePortrait(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageInput string `json:"image_input" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeImageToImage)
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
switch req.TaskType {
|
||||
case "text_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToImage)
|
||||
taskType = model.JMTaskTypeTextToImage
|
||||
reqKey = jimeng.ReqKeyTextToImage
|
||||
modelName = "即梦文生图"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
case "image_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToImage)
|
||||
taskType = model.JMTaskTypeImageToImage
|
||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||
modelName = "即梦图生图"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
req.Prompt = "演唱会现场的合照,闪光灯拍摄"
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeImageToImage,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyImageToImagePortrait,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image to image portrait task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
params = map[string]interface{}{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
case "image_edit":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEdit)
|
||||
taskType = model.JMTaskTypeImageEdit
|
||||
reqKey = jimeng.ReqKeyImageEdit
|
||||
modelName = "即梦图像编辑"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
taskType = model.JMTaskTypeImageEffects
|
||||
reqKey = jimeng.ReqKeyImageEffects
|
||||
modelName = "即梦图像特效"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
params = map[string]interface{}{
|
||||
"image_input1": req.ImageInput,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
case "text_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
|
||||
taskType = model.JMTaskTypeTextToVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
params = map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
case "image_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
|
||||
taskType = model.JMTaskTypeImageToVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
default:
|
||||
resp.ERROR(c, "不支持的任务类型")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图生图",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageEdit 图像编辑
|
||||
func (h *JimengHandler) ImageEdit(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
Scale float64 `json:"scale"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
|
||||
resp.ERROR(c, "请提供图片URL或Base64数据")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeImageEdit)
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeImageEdit,
|
||||
Type: taskType,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyImageEdit,
|
||||
ReqKey: reqKey,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image edit task failed: %v", err)
|
||||
logger.Errorf("create jimeng task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图像编辑",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageEffects 图像特效
|
||||
func (h *JimengHandler) ImageEffects(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageInput1 string `json:"image_input1" binding:"required"`
|
||||
TemplateId string `json:"template_id" binding:"required"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"image_input1": req.ImageInput1,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeImageEffects,
|
||||
Prompt: "",
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyImageEffects,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image effects task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图像特效",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// TextToVideo 文生视频
|
||||
func (h *JimengHandler) TextToVideo(c *gin.Context) {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeTextToVideo,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyTextToVideo,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng text to video task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦文生视频",
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// ImageToVideo 图生视频
|
||||
func (h *JimengHandler) ImageToVideo(c *gin.Context) {
|
||||
var req struct {
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Prompt string `json:"prompt"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
|
||||
resp.ERROR(c, "请提供图片URL或Base64数据")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置中的算力消耗
|
||||
powerCost := h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
|
||||
|
||||
// 检查用户算力
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
// 构建任务参数
|
||||
params := map[string]interface{}{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: model.JMTaskTypeImageToVideo,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: jimeng.ReqKeyImageToVideo,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng image to video task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 扣除用户算力
|
||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "即梦图生视频",
|
||||
Model: modelName,
|
||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
||||
})
|
||||
|
||||
@@ -551,24 +277,6 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// PendingCount 获取未完成任务数量
|
||||
func (h *JimengHandler) PendingCount(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.jimengService.GetPendingTaskCount(user.Id)
|
||||
if err != nil {
|
||||
logger.Errorf("get pending task count failed: %v", err)
|
||||
resp.ERROR(c, "获取待处理任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
@@ -583,6 +291,21 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取任务,判断状态
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
if job.UserId != user.Id {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能删除")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.jimengService.DeleteJob(uint(jobId), user.Id); err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
@@ -709,3 +432,20 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
return 10
|
||||
}
|
||||
}
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config, err := h.jimengService.GetConfig()
|
||||
if err != nil || config == nil {
|
||||
resp.ERROR(c, "获取算力配置失败")
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
"image_edit": config.Power.ImageEdit,
|
||||
"image_effects": config.Power.ImageEffects,
|
||||
"text_to_video": config.Power.TextToVideo,
|
||||
"image_to_video": config.Power.ImageToVideo,
|
||||
})
|
||||
}
|
||||
|
||||
13
api/main.go
13
api/main.go
@@ -208,21 +208,10 @@ func main() {
|
||||
}),
|
||||
|
||||
// 即梦AI 服务
|
||||
fx.Provide(func(config *types.AppConfig) *jimeng.Client {
|
||||
// 使用默认配置初始化客户端,后续会从数据库加载
|
||||
return jimeng.NewClient("", "")
|
||||
}),
|
||||
fx.Provide(jimeng.NewService),
|
||||
fx.Invoke(func(service *jimeng.Service) {
|
||||
// 从数据库加载配置
|
||||
err := service.LoadConfigFromDB()
|
||||
if err != nil {
|
||||
logger.Errorf("加载即梦AI配置失败: %v", err)
|
||||
}
|
||||
}),
|
||||
fx.Provide(jimeng.NewConsumer),
|
||||
fx.Invoke(func(consumer *jimeng.Consumer) {
|
||||
consumer.Start()
|
||||
//consumer.Start()
|
||||
go consumer.MonitorQueue()
|
||||
}),
|
||||
fx.Provide(service.NewUserService),
|
||||
|
||||
@@ -8,11 +8,8 @@ import (
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"geekai/logger"
|
||||
)
|
||||
|
||||
var clientLogger = logger.GetLogger()
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
@@ -80,7 +77,7 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
@@ -105,7 +102,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
@@ -130,7 +127,7 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
@@ -139,4 +136,4 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (c *Consumer) consume() {
|
||||
// processTask 处理任务
|
||||
func (c *Consumer) processTask() {
|
||||
// 从队列中获取任务
|
||||
var task map[string]interface{}
|
||||
var task map[string]any
|
||||
if err := c.service.taskQueue.LPop(&task); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"geekai/logger"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var serviceLogger = logger.GetLogger()
|
||||
var looger = logger2.GetLogger()
|
||||
|
||||
// Service 即梦服务
|
||||
type Service struct {
|
||||
@@ -29,8 +30,16 @@ type Service struct {
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, client *Client) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
@@ -99,7 +108,7 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return s.processTextToImage(&job)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return s.processImageToImagePortrait(&job)
|
||||
return s.processImageToImage(&job)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return s.processImageEdit(&job)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
@@ -171,15 +180,15 @@ func (s *Service) processTextToImage(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageToImagePortrait 处理图生图人像写真任务
|
||||
func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
|
||||
// processImageToImage 处理图生图任务
|
||||
func (s *Service) processImageToImage(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
@@ -249,7 +258,7 @@ func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -315,7 +324,7 @@ func (s *Service) processImageEdit(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -370,7 +379,7 @@ func (s *Service) processImageEffects(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -418,7 +427,7 @@ func (s *Service) processTextToVideo(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -482,7 +491,7 @@ func (s *Service) processImageToVideo(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -505,7 +514,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
serviceLogger.Errorf("query jimeng task status failed: %v", err)
|
||||
looger.Errorf("query jimeng task status failed: %v", err)
|
||||
retryCount++
|
||||
continue
|
||||
}
|
||||
@@ -555,7 +564,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status))
|
||||
|
||||
default:
|
||||
serviceLogger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
looger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
retryCount++
|
||||
@@ -587,7 +596,7 @@ func (s *Service) UpdateJobProgress(jobId uint, progress int) error {
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
serviceLogger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
looger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
@@ -635,13 +644,12 @@ func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列
|
||||
func (s *Service) PushTaskToQueue(task map[string]interface{}) error {
|
||||
func (s *Service) PushTaskToQueue(task map[string]any) error {
|
||||
return s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// TestConnection 测试即梦AI连接
|
||||
func (s *Service) TestConnection(accessKey, secretKey string) error {
|
||||
// 创建临时客户端进行测试
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
@@ -655,13 +663,12 @@ func (s *Service) TestConnection(accessKey, secretKey string) error {
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if err.Error() == "unauthorized" || err.Error() == "access denied" {
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -671,9 +678,9 @@ func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.TestConnection(accessKey, secretKey)
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("新配置测试失败: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
@@ -709,18 +716,3 @@ func (s *Service) GetConfig() (*types.JimengConfig, error) {
|
||||
|
||||
return &jimengConfig, nil
|
||||
}
|
||||
|
||||
// LoadConfigFromDB 从数据库加载配置并更新客户端
|
||||
func (s *Service) LoadConfigFromDB() error {
|
||||
config, err := s.GetConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果配置中有AccessKey和SecretKey,则更新客户端
|
||||
if config.AccessKey != "" && config.SecretKey != "" {
|
||||
return s.UpdateClientConfig(config.AccessKey, config.SecretKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user