From 54fe49de5d311f7ba3855abb39b18176e576aa18 Mon Sep 17 00:00:00 2001 From: GeekMaster Date: Tue, 22 Jul 2025 16:46:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=88=E5=B9=B6=E6=9C=8D=E5=8A=A1=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/handler/jimeng_handler.go | 93 ++---- api/main.go | 6 +- api/service/jimeng/client.go | 6 +- api/service/jimeng/consumer.go | 177 ----------- api/service/jimeng/service.go | 527 ++++++++++++--------------------- 5 files changed, 223 insertions(+), 586 deletions(-) delete mode 100644 api/service/jimeng/consumer.go diff --git a/api/handler/jimeng_handler.go b/api/handler/jimeng_handler.go index cd8a5ef7..4590356c 100644 --- a/api/handler/jimeng_handler.go +++ b/api/handler/jimeng_handler.go @@ -2,10 +2,10 @@ package handler import ( "fmt" - "time" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/jimeng" "geekai/store/model" "geekai/store/vo" @@ -20,13 +20,15 @@ import ( type JimengHandler struct { BaseHandler jimengService *jimeng.Service + userService *service.UserService } // NewJimengHandler 创建即梦AI处理器 -func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB) *JimengHandler { +func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler { return &JimengHandler{ BaseHandler: BaseHandler{App: app, DB: db}, jimengService: jimengService, + userService: userService, } } @@ -79,9 +81,19 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { return } + if req.Width == 0 { + req.Width = 1328 + } + if req.Height == 0 { + req.Height = 1328 + } + if req.Seed == 0 { + req.Seed = -1 + } + var powerCost int var taskType model.JMTaskType - var params map[string]interface{} + var params map[string]any var reqKey string var modelName string @@ -94,16 +106,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { if req.Scale == 0 { req.Scale = 2.5 } - if req.Width == 0 { - req.Width = 1328 - } - if req.Height == 0 { - req.Height = 1328 - } - if req.Seed == 0 { - req.Seed = -1 - } - params = map[string]interface{}{ + params = map[string]any{ "seed": req.Seed, "scale": req.Scale, "width": req.Width, @@ -115,12 +118,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { taskType = model.JMTaskTypeImageToImage reqKey = jimeng.ReqKeyImageToImagePortrait modelName = "即梦图生图" - if req.Width == 0 { - req.Width = 1328 - } - if req.Height == 0 { - req.Height = 1328 - } if req.Gpen == 0 { req.Gpen = 0.4 } @@ -134,11 +131,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { req.GenMode = jimeng.GenModeReference } } - if req.Seed == 0 { - req.Seed = -1 - } - - params = map[string]interface{}{ + params = map[string]any{ "image_input": req.ImageInput, "width": req.Width, "height": req.Height, @@ -156,10 +149,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { if req.Scale == 0 { req.Scale = 0.5 } - if req.Seed == 0 { - req.Seed = -1 - } - params = map[string]interface{}{ + params = map[string]any{ "seed": req.Seed, "scale": req.Scale, } @@ -180,7 +170,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { if req.Height == 0 { req.Height = 1328 } - params = map[string]interface{}{ + params = map[string]any{ "image_input1": req.ImageInput, "template_id": req.TemplateId, "width": req.Width, @@ -197,7 +187,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { if req.AspectRatio == "" { req.AspectRatio = jimeng.AspectRatio16_9 } - params = map[string]interface{}{ + params = map[string]any{ "seed": req.Seed, "aspect_ratio": req.AspectRatio, } @@ -209,7 +199,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { if req.Seed == 0 { req.Seed = -1 } - params = map[string]interface{}{ + params = map[string]any{ "seed": req.Seed, "aspect_ratio": req.AspectRatio, } @@ -244,10 +234,10 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { return } - h.subUserPower(user.Id, powerCost, model.PowerLog{ + h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{ Type: types.PowerConsume, - Model: modelName, - Remark: fmt.Sprintf("任务ID:%d", job.Id), + Model: "jimeng", + Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id), }) resp.SUCCESS(c, job) @@ -272,14 +262,16 @@ func (h *JimengHandler) Jobs(c *gin.Context) { var jobs []model.JimengJob var total int64 query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId) - if req.Filter == "image" { + + switch req.Filter { + case "image": query = query.Where("type IN (?)", []model.JMTaskType{ model.JMTaskTypeTextToImage, model.JMTaskTypeImageToImage, model.JMTaskTypeImageEdit, model.JMTaskTypeImageEffects, }) - } else if req.Filter == "video" { + case "video": query = query.Where("type IN (?)", []model.JMTaskType{ model.JMTaskTypeTextToVideo, model.JMTaskTypeImageToVideo, @@ -395,11 +387,7 @@ func (h *JimengHandler) Retry(c *gin.Context) { } // 重新推送到队列 - task := map[string]any{ - "job_id": jobId, - "type": job.Type, - } - if err := h.jimengService.PushTaskToQueue(task); err != nil { + if err := h.jimengService.PushTaskToQueue(uint(jobId)); err != nil { logger.Errorf("push retry task to queue failed: %v", err) resp.ERROR(c, "推送重试任务失败") return @@ -408,29 +396,6 @@ func (h *JimengHandler) Retry(c *gin.Context) { resp.SUCCESS(c, gin.H{"message": "重试任务已提交"}) } -// subUserPower 扣除用户算力 -func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.PowerLog) { - session := h.DB.Session(&gorm.Session{}) - - // 更新用户算力 - if err := session.Model(&model.User{}).Where("id = ?", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error; err != nil { - logger.Errorf("update user power failed: %v", err) - return - } - - // 记录算力消费日志 - powerLog.UserId = userId - powerLog.Amount = power - powerLog.Mark = types.PowerSub - powerLog.CreatedAt = time.Now() - if err := session.Create(&powerLog).Error; err != nil { - logger.Errorf("create power log failed: %v", err) - return - } - - session.Commit() -} - // getPowerFromConfig 从配置中获取指定类型的算力消耗 func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int { config := h.jimengService.GetConfig() diff --git a/api/main.go b/api/main.go index 6dc7ce50..7ddd1a6e 100644 --- a/api/main.go +++ b/api/main.go @@ -209,10 +209,8 @@ func main() { // 即梦AI 服务 fx.Provide(jimeng.NewService), - fx.Provide(jimeng.NewConsumer), - fx.Invoke(func(consumer *jimeng.Consumer) { - //consumer.Start() - go consumer.MonitorQueue() + fx.Invoke(func(service *jimeng.Service) { + service.Start() }), fx.Provide(service.NewUserService), fx.Provide(payment.NewAlipayService), diff --git a/api/service/jimeng/client.go b/api/service/jimeng/client.go index 0f15104b..24fa0126 100644 --- a/api/service/jimeng/client.go +++ b/api/service/jimeng/client.go @@ -77,7 +77,7 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err) } - looger.Infof("Jimeng SubmitTask Response: %s", string(respBody)) + logger.Infof("Jimeng SubmitTask Response: %s", string(respBody)) // 解析响应 var result SubmitTaskResponse @@ -102,7 +102,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) { return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err) } - looger.Infof("Jimeng QueryTask Response: %s", string(respBody)) + logger.Infof("Jimeng QueryTask Response: %s", string(respBody)) // 解析响应 var result QueryTaskResponse @@ -127,7 +127,7 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err) } - looger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody)) + logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody)) // 解析响应,同步任务直接返回结果 var result QueryTaskResponse diff --git a/api/service/jimeng/consumer.go b/api/service/jimeng/consumer.go deleted file mode 100644 index c42eff1a..00000000 --- a/api/service/jimeng/consumer.go +++ /dev/null @@ -1,177 +0,0 @@ -package jimeng - -import ( - "context" - "time" - - "geekai/logger" - "geekai/store/model" -) - -var jimengLogger = logger.GetLogger() - -// Consumer 即梦任务消费者 -type Consumer struct { - service *Service - ctx context.Context - cancel context.CancelFunc -} - -// NewConsumer 创建即梦任务消费者 -func NewConsumer(service *Service) *Consumer { - ctx, cancel := context.WithCancel(context.Background()) - return &Consumer{ - service: service, - ctx: ctx, - cancel: cancel, - } -} - -// Start 启动消费者 -func (c *Consumer) Start() { - jimengLogger.Info("Starting Jimeng task consumer...") - go c.consume() -} - -// Stop 停止消费者 -func (c *Consumer) Stop() { - jimengLogger.Info("Stopping Jimeng task consumer...") - c.cancel() -} - -// consume 消费任务 -func (c *Consumer) consume() { - for { - select { - case <-c.ctx.Done(): - jimengLogger.Info("Jimeng task consumer stopped") - return - default: - c.processTask() - } - } -} - -// processTask 处理任务 -func (c *Consumer) processTask() { - // 从队列中获取任务 - var task map[string]any - if err := c.service.taskQueue.LPop(&task); err != nil { - // 队列为空,等待1秒后重试 - time.Sleep(time.Second) - return - } - - // 解析任务 - jobIdFloat, ok := task["job_id"].(float64) - if !ok { - jimengLogger.Errorf("invalid job_id in task: %v", task) - return - } - jobId := uint(jobIdFloat) - - taskType, ok := task["type"].(string) - if !ok { - jimengLogger.Errorf("invalid task type in task: %v", task) - return - } - - jimengLogger.Infof("Processing Jimeng task: job_id=%d, type=%s", jobId, taskType) - - // 处理任务 - if err := c.service.ProcessTask(jobId); err != nil { - jimengLogger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err) - - // 任务失败,直接标记为失败状态,不进行重试 - c.service.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error()) - } else { - jimengLogger.Infof("Jimeng task processed successfully: job_id=%d", jobId) - } -} - -// TaskQueueStatus 任务队列状态 -type TaskQueueStatus struct { - QueueLength int `json:"queue_length"` - ActiveTasks int `json:"active_tasks"` -} - -// GetQueueStatus 获取队列状态 -func (c *Consumer) GetQueueStatus() (*TaskQueueStatus, error) { - // 获取队列长度 - length, err := c.service.taskQueue.Size() - if err != nil { - return nil, err - } - - // 获取活跃任务数(正在处理的任务) - activeTasks, err := c.service.GetPendingTaskCount(0) // 0表示所有用户 - if err != nil { - activeTasks = 0 - } - - return &TaskQueueStatus{ - QueueLength: int(length), - ActiveTasks: int(activeTasks), - }, nil -} - -// MonitorQueue 监控队列状态 -func (c *Consumer) MonitorQueue() { - ticker := time.NewTicker(30 * time.Second) // 每30秒监控一次 - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - return - case <-ticker.C: - status, err := c.GetQueueStatus() - if err != nil { - jimengLogger.Errorf("get queue status failed: %v", err) - continue - } - - if status.QueueLength > 0 || status.ActiveTasks > 0 { - jimengLogger.Infof("Jimeng queue status: queue_length=%d, active_tasks=%d", - status.QueueLength, status.ActiveTasks) - } - } - } -} - -// PushTaskToQueue 推送任务到队列(用于手动重试) -func (c *Consumer) PushTaskToQueue(task map[string]interface{}) error { - return c.service.taskQueue.RPush(task) -} - -// GetTaskStats 获取任务统计信息 -func (c *Consumer) GetTaskStats() (map[string]any, error) { - type StatResult struct { - Status string `json:"status"` - Count int64 `json:"count"` - } - - var stats []StatResult - err := c.service.db.Model(&model.JimengJob{}). - Select("status, COUNT(*) as count"). - Group("status"). - Find(&stats).Error - if err != nil { - return nil, err - } - - result := map[string]any{ - "total": int64(0), - "completed": int64(0), - "processing": int64(0), - "failed": int64(0), - "pending": int64(0), - } - - for _, stat := range stats { - result["total"] = result["total"].(int64) + stat.Count - result[stat.Status] = stat.Count - } - - return result, nil -} diff --git a/api/service/jimeng/service.go b/api/service/jimeng/service.go index 69ef0a2c..2fbd3624 100644 --- a/api/service/jimeng/service.go +++ b/api/service/jimeng/service.go @@ -1,6 +1,7 @@ package jimeng import ( + "context" "encoding/json" "fmt" "strconv" @@ -19,14 +20,17 @@ import ( "github.com/go-redis/redis/v8" ) -var looger = logger2.GetLogger() +var logger = logger2.GetLogger() -// Service 即梦服务 +// Service 即梦服务(合并了消费者功能) type Service struct { db *gorm.DB redis *redis.Client taskQueue *store.RedisQueue client *Client + ctx context.Context + cancel context.CancelFunc + running bool } // NewService 创建即梦服务 @@ -40,11 +44,68 @@ func NewService(db *gorm.DB, redisCli *redis.Client) *Service { _ = utils.JsonDecode(config.Value, &jimengConfig) } client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey) + + ctx, cancel := context.WithCancel(context.Background()) return &Service{ db: db, redis: redisCli, taskQueue: taskQueue, client: client, + ctx: ctx, + cancel: cancel, + running: false, + } +} + +// Start 启动服务(包含消费者) +func (s *Service) Start() { + if s.running { + return + } + logger.Info("Starting Jimeng service and task consumer...") + s.running = true + go s.consumeTasks() +} + +// Stop 停止服务 +func (s *Service) Stop() { + if !s.running { + return + } + logger.Info("Stopping Jimeng service and task consumer...") + s.running = false + s.cancel() +} + +// consumeTasks 消费任务 +func (s *Service) consumeTasks() { + for { + select { + case <-s.ctx.Done(): + logger.Info("Jimeng task consumer stopped") + return + default: + s.processNextTask() + } + } +} + +// processNextTask 处理下一个任务 +func (s *Service) processNextTask() { + var jobId uint + if err := s.taskQueue.LPop(&jobId); err != nil { + // 队列为空,等待1秒后重试 + time.Sleep(time.Second) + return + } + + logger.Infof("Processing Jimeng task: job_id=%d", jobId) + + if err := s.ProcessTask(jobId); err != nil { + logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err) + s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error()) + } else { + logger.Infof("Jimeng task processed successfully: job_id=%d", jobId) } } @@ -79,11 +140,7 @@ func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.Jimeng } // 推送到任务队列 - task := map[string]any{ - "job_id": job.Id, - "type": job.Type, - } - if err := s.taskQueue.RPush(task); err != nil { + if err := s.taskQueue.RPush(job.Id); err != nil { return nil, fmt.Errorf("push jimeng task to queue failed: %w", err) } @@ -103,40 +160,73 @@ func (s *Service) ProcessTask(jobId uint) error { return fmt.Errorf("update job status failed: %w", err) } - // 根据任务类型处理 - switch job.Type { - case model.JMTaskTypeTextToImage: - return s.processTextToImage(&job) - case model.JMTaskTypeImageToImage: - return s.processImageToImage(&job) - case model.JMTaskTypeImageEdit: - return s.processImageEdit(&job) - case model.JMTaskTypeImageEffects: - return s.processImageEffects(&job) - case model.JMTaskTypeTextToVideo: - return s.processTextToVideo(&job) - case model.JMTaskTypeImageToVideo: - return s.processImageToVideo(&job) - default: - return fmt.Errorf("unsupported task type: %s", job.Type) + // 构建请求并提交任务 + req, err := s.buildTaskRequest(&job) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err)) } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + logger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) } -// processTextToImage 处理文生图任务 -func (s *Service) processTextToImage(job *model.JimengJob) error { +// buildTaskRequest 构建任务请求(统一的参数解析) +func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) { // 解析任务参数 var params map[string]any if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + return nil, fmt.Errorf("parse task params failed: %w", err) } - // 构建请求 + // 构建基础请求 req := &SubmitTaskRequest{ ReqKey: job.ReqKey, Prompt: job.Prompt, } - // 设置参数 + // 根据任务类型设置特定参数 + switch job.Type { + case model.JMTaskTypeTextToImage: + s.setTextToImageParams(req, params) + case model.JMTaskTypeImageToImage: + s.setImageToImageParams(req, params) + case model.JMTaskTypeImageEdit: + s.setImageEditParams(req, params) + case model.JMTaskTypeImageEffects: + s.setImageEffectsParams(req, params) + case model.JMTaskTypeTextToVideo: + s.setTextToVideoParams(req, params) + case model.JMTaskTypeImageToVideo: + s.setImageToVideoParams(req, params) + default: + return nil, fmt.Errorf("unsupported task type: %s", job.Type) + } + + return req, nil +} + +// setTextToImageParams 设置文生图参数 +func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) { if seed, ok := params["seed"]; ok { if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { req.Seed = seedVal @@ -162,51 +252,13 @@ func (s *Service) processTextToImage(job *model.JimengJob) error { req.UsePreLLM = usePreLlmVal } } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) } -// processImageToImage 处理图生图任务 -func (s *Service) processImageToImage(job *model.JimengJob) error { - // 解析任务参数 - var params map[string]any - if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) - } - - // 构建请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - Prompt: job.Prompt, - } - - // 设置图像输入 +// setImageToImageParams 设置图生图参数 +func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) { if imageInput, ok := params["image_input"].(string); ok { req.ImageInput = imageInput } - - // 设置其他参数 if gpen, ok := params["gpen"]; ok { if gpenVal, ok := gpen.(float64); ok { req.Gpen = gpenVal @@ -225,61 +277,11 @@ func (s *Service) processImageToImage(job *model.JimengJob) error { if genMode, ok := params["gen_mode"].(string); ok { req.GenMode = genMode } - if width, ok := params["width"]; ok { - if widthVal, ok := width.(float64); ok { - req.Width = int(widthVal) - } - } - if height, ok := params["height"]; ok { - if heightVal, ok := height.(float64); ok { - req.Height = int(heightVal) - } - } - if seed, ok := params["seed"]; ok { - if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { - req.Seed = seedVal - } - } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) + s.setCommonParams(req, params) // 复用通用参数 } -// processImageEdit 处理图像编辑任务 -func (s *Service) processImageEdit(job *model.JimengJob) error { - // 解析任务参数 - var params map[string]any - if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) - } - - // 构建请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - Prompt: job.Prompt, - } - - // 设置图像输入 +// setImageEditParams 设置图像编辑参数 +func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) { if imageUrls, ok := params["image_urls"].([]any); ok { for _, url := range imageUrls { if urlStr, ok := url.(string); ok { @@ -294,57 +296,16 @@ func (s *Service) processImageEdit(job *model.JimengJob) error { } } } - - // 设置其他参数 - if seed, ok := params["seed"]; ok { - if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { - req.Seed = seedVal - } - } if scale, ok := params["scale"]; ok { if scaleVal, ok := scale.(float64); ok { req.Scale = scaleVal } } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) + s.setCommonParams(req, params) } -// processImageEffects 处理图像特效任务 -func (s *Service) processImageEffects(job *model.JimengJob) error { - // 解析任务参数 - var params map[string]any - if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) - } - - // 构建请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - } - - // 设置图像输入 +// setImageEffectsParams 设置图像特效参数 +func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) { if imageInput1, ok := params["image_input1"].(string); ok { req.ImageInput1 = imageInput1 } @@ -361,141 +322,41 @@ func (s *Service) processImageEffects(job *model.JimengJob) error { req.Height = int(heightVal) } } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) } -// processTextToVideo 处理文生视频任务 -func (s *Service) processTextToVideo(job *model.JimengJob) error { - // 解析任务参数 - var params map[string]any - if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) +// setTextToVideoParams 设置文生视频参数 +func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) { + if aspectRatio, ok := params["aspect_ratio"].(string); ok { + req.AspectRatio = aspectRatio } + s.setCommonParams(req, params) +} - // 构建请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - Prompt: job.Prompt, +// setImageToVideoParams 设置图生视频参数 +func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) { + s.setImageEditParams(req, params) // 复用图像编辑的参数设置 + if aspectRatio, ok := params["aspect_ratio"].(string); ok { + req.AspectRatio = aspectRatio } +} - // 设置参数 +// setCommonParams 设置通用参数(seed, width, height等) +func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) { if seed, ok := params["seed"]; ok { if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { req.Seed = seedVal } } - if aspectRatio, ok := params["aspect_ratio"].(string); ok { - req.AspectRatio = aspectRatio - } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) -} - -// processImageToVideo 处理图生视频任务 -func (s *Service) processImageToVideo(job *model.JimengJob) error { - // 解析任务参数 - var params map[string]any - if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) - } - - // 构建请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - Prompt: job.Prompt, - } - - // 设置图像输入 - if imageUrls, ok := params["image_urls"].([]any); ok { - for _, url := range imageUrls { - if urlStr, ok := url.(string); ok { - req.ImageUrls = append(req.ImageUrls, urlStr) - } + if width, ok := params["width"]; ok { + if widthVal, ok := width.(float64); ok { + req.Width = int(widthVal) } } - if binaryData, ok := params["binary_data_base64"].([]any); ok { - for _, data := range binaryData { - if dataStr, ok := data.(string); ok { - req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr) - } + if height, ok := params["height"]; ok { + if heightVal, ok := height.(float64); ok { + req.Height = int(heightVal) } } - - // 设置其他参数 - if seed, ok := params["seed"]; ok { - if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { - req.Seed = seedVal - } - } - if aspectRatio, ok := params["aspect_ratio"].(string); ok { - req.AspectRatio = aspectRatio - } - - // 提交异步任务 - resp, err := s.client.SubmitTask(req) - if err != nil { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) - } - - if resp.Code != 10000 { - return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) - } - - // 更新任务ID和原始数据 - rawData, _ := json.Marshal(resp) - if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ - "task_id": resp.Data.TaskId, - "raw_data": string(rawData), - "updated_at": time.Now(), - }).Error; err != nil { - looger.Errorf("update jimeng job task_id failed: %v", err) - } - - // 开始轮询任务状态 - return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) } // pollTaskStatus 轮询任务状态 @@ -514,7 +375,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error { }) if err != nil { - looger.Errorf("query jimeng task status failed: %v", err) + logger.Errorf("query jimeng task status failed: %v", err) retryCount++ continue } @@ -537,7 +398,6 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error { // 任务完成,更新结果 updates := map[string]any{ "status": model.JMTaskStatusSuccess, - "progress": 100, "updated_at": time.Now(), } @@ -553,18 +413,18 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error { case model.JMTaskStatusInQueue: // 任务在队列中 - s.UpdateJobProgress(jobId, 10) + s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "") case model.JMTaskStatusGenerating: // 任务处理中 - s.UpdateJobProgress(jobId, 50) + s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "") case model.JMTaskStatusNotFound: // 任务未找到或已过期 - return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status)) + return s.handleTaskError(jobId, resp.Message) default: - looger.Warnf("unknown task status: %s", resp.Data.Status) + logger.Warnf("unknown task status: %s", resp.Data.Status) } retryCount++ @@ -586,20 +446,49 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error } -// UpdateJobProgress 更新任务进度 -func (s *Service) UpdateJobProgress(jobId uint, progress int) error { - return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(map[string]any{ - "progress": progress, - "updated_at": time.Now(), - }).Error -} - // handleTaskError 处理任务错误 func (s *Service) handleTaskError(jobId uint, errMsg string) error { - looger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg) + logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg) return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg) } +// PushTaskToQueue 推送任务到队列(用于手动重试) +func (s *Service) PushTaskToQueue(jobId uint) error { + return s.taskQueue.RPush(jobId) +} + +// GetTaskStats 获取任务统计信息 +func (s *Service) GetTaskStats() (map[string]any, error) { + type StatResult struct { + Status string `json:"status"` + Count int64 `json:"count"` + } + + var stats []StatResult + err := s.db.Model(&model.JimengJob{}). + Select("status, COUNT(*) as count"). + Group("status"). + Find(&stats).Error + if err != nil { + return nil, err + } + + result := map[string]any{ + "total": int64(0), + "completed": int64(0), + "processing": int64(0), + "failed": int64(0), + "pending": int64(0), + } + + for _, stat := range stats { + result["total"] = result["total"].(int64) + stat.Count + result[stat.Status] = stat.Count + } + + return result, nil +} + // GetJob 获取任务 func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) { var job model.JimengJob @@ -609,54 +498,16 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) { return &job, nil } -// GetJobByPage 分页获取任务列表 -func (s *Service) GetJobByPage(userId uint, page, pageSize int) ([]*model.JimengJob, int64, error) { - var jobs []*model.JimengJob - var total int64 - - query := s.db.Model(&model.JimengJob{}) - if userId > 0 { - query = query.Where("user_id = ?", userId) - } - - // 统计总数 - if err := query.Count(&total).Error; err != nil { - return nil, 0, err - } - - // 分页查询 - offset := (page - 1) * pageSize - if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil { - return nil, 0, err - } - - return jobs, total, nil -} - -// GetPendingTaskCount 获取用户未完成任务数量 -func (s *Service) GetPendingTaskCount(userId uint) (int64, error) { - var count int64 - err := s.db.Model(&model.JimengJob{}).Where("user_id = ? AND status IN (?)", userId, - []model.JMTaskStatus{model.JMTaskStatusInQueue, model.JMTaskStatusGenerating}).Count(&count).Error - return count, err -} - // DeleteJob 删除任务 func (s *Service) DeleteJob(jobId uint, userId uint) error { return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error } -// PushTaskToQueue 推送任务到队列 -func (s *Service) PushTaskToQueue(task map[string]any) error { - return s.taskQueue.RPush(task) -} - // testConnection 测试即梦AI连接 func (s *Service) testConnection(accessKey, secretKey string) error { testClient := NewClient(accessKey, secretKey) // 使用一个简单的查询任务来测试连接 - // 这里使用一个不存在的任务ID来测试API连接是否正常 testReq := &QueryTaskRequest{ ReqKey: "test_connection", TaskId: "test_task_id_12345",