mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-05 09:24:29 +08:00
合并服务代码
This commit is contained in:
@@ -2,10 +2,10 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/service/jimeng"
|
"geekai/service/jimeng"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/store/vo"
|
"geekai/store/vo"
|
||||||
@@ -20,13 +20,15 @@ import (
|
|||||||
type JimengHandler struct {
|
type JimengHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
jimengService *jimeng.Service
|
jimengService *jimeng.Service
|
||||||
|
userService *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJimengHandler 创建即梦AI处理器
|
// NewJimengHandler 创建即梦AI处理器
|
||||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB) *JimengHandler {
|
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
|
||||||
return &JimengHandler{
|
return &JimengHandler{
|
||||||
BaseHandler: BaseHandler{App: app, DB: db},
|
BaseHandler: BaseHandler{App: app, DB: db},
|
||||||
jimengService: jimengService,
|
jimengService: jimengService,
|
||||||
|
userService: userService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,9 +81,19 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Width == 0 {
|
||||||
|
req.Width = 1328
|
||||||
|
}
|
||||||
|
if req.Height == 0 {
|
||||||
|
req.Height = 1328
|
||||||
|
}
|
||||||
|
if req.Seed == 0 {
|
||||||
|
req.Seed = -1
|
||||||
|
}
|
||||||
|
|
||||||
var powerCost int
|
var powerCost int
|
||||||
var taskType model.JMTaskType
|
var taskType model.JMTaskType
|
||||||
var params map[string]interface{}
|
var params map[string]any
|
||||||
var reqKey string
|
var reqKey string
|
||||||
var modelName string
|
var modelName string
|
||||||
|
|
||||||
@@ -94,16 +106,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
if req.Scale == 0 {
|
if req.Scale == 0 {
|
||||||
req.Scale = 2.5
|
req.Scale = 2.5
|
||||||
}
|
}
|
||||||
if req.Width == 0 {
|
params = map[string]any{
|
||||||
req.Width = 1328
|
|
||||||
}
|
|
||||||
if req.Height == 0 {
|
|
||||||
req.Height = 1328
|
|
||||||
}
|
|
||||||
if req.Seed == 0 {
|
|
||||||
req.Seed = -1
|
|
||||||
}
|
|
||||||
params = map[string]interface{}{
|
|
||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"scale": req.Scale,
|
"scale": req.Scale,
|
||||||
"width": req.Width,
|
"width": req.Width,
|
||||||
@@ -115,12 +118,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
taskType = model.JMTaskTypeImageToImage
|
taskType = model.JMTaskTypeImageToImage
|
||||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||||
modelName = "即梦图生图"
|
modelName = "即梦图生图"
|
||||||
if req.Width == 0 {
|
|
||||||
req.Width = 1328
|
|
||||||
}
|
|
||||||
if req.Height == 0 {
|
|
||||||
req.Height = 1328
|
|
||||||
}
|
|
||||||
if req.Gpen == 0 {
|
if req.Gpen == 0 {
|
||||||
req.Gpen = 0.4
|
req.Gpen = 0.4
|
||||||
}
|
}
|
||||||
@@ -134,11 +131,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
req.GenMode = jimeng.GenModeReference
|
req.GenMode = jimeng.GenModeReference
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.Seed == 0 {
|
params = map[string]any{
|
||||||
req.Seed = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
params = map[string]interface{}{
|
|
||||||
"image_input": req.ImageInput,
|
"image_input": req.ImageInput,
|
||||||
"width": req.Width,
|
"width": req.Width,
|
||||||
"height": req.Height,
|
"height": req.Height,
|
||||||
@@ -156,10 +149,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
if req.Scale == 0 {
|
if req.Scale == 0 {
|
||||||
req.Scale = 0.5
|
req.Scale = 0.5
|
||||||
}
|
}
|
||||||
if req.Seed == 0 {
|
params = map[string]any{
|
||||||
req.Seed = -1
|
|
||||||
}
|
|
||||||
params = map[string]interface{}{
|
|
||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"scale": req.Scale,
|
"scale": req.Scale,
|
||||||
}
|
}
|
||||||
@@ -180,7 +170,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
if req.Height == 0 {
|
if req.Height == 0 {
|
||||||
req.Height = 1328
|
req.Height = 1328
|
||||||
}
|
}
|
||||||
params = map[string]interface{}{
|
params = map[string]any{
|
||||||
"image_input1": req.ImageInput,
|
"image_input1": req.ImageInput,
|
||||||
"template_id": req.TemplateId,
|
"template_id": req.TemplateId,
|
||||||
"width": req.Width,
|
"width": req.Width,
|
||||||
@@ -197,7 +187,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
if req.AspectRatio == "" {
|
if req.AspectRatio == "" {
|
||||||
req.AspectRatio = jimeng.AspectRatio16_9
|
req.AspectRatio = jimeng.AspectRatio16_9
|
||||||
}
|
}
|
||||||
params = map[string]interface{}{
|
params = map[string]any{
|
||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"aspect_ratio": req.AspectRatio,
|
"aspect_ratio": req.AspectRatio,
|
||||||
}
|
}
|
||||||
@@ -209,7 +199,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
if req.Seed == 0 {
|
if req.Seed == 0 {
|
||||||
req.Seed = -1
|
req.Seed = -1
|
||||||
}
|
}
|
||||||
params = map[string]interface{}{
|
params = map[string]any{
|
||||||
"seed": req.Seed,
|
"seed": req.Seed,
|
||||||
"aspect_ratio": req.AspectRatio,
|
"aspect_ratio": req.AspectRatio,
|
||||||
}
|
}
|
||||||
@@ -244,10 +234,10 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.subUserPower(user.Id, powerCost, model.PowerLog{
|
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||||
Type: types.PowerConsume,
|
Type: types.PowerConsume,
|
||||||
Model: modelName,
|
Model: "jimeng",
|
||||||
Remark: fmt.Sprintf("任务ID:%d", job.Id),
|
Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp.SUCCESS(c, job)
|
resp.SUCCESS(c, job)
|
||||||
@@ -272,14 +262,16 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
|
|||||||
var jobs []model.JimengJob
|
var jobs []model.JimengJob
|
||||||
var total int64
|
var total int64
|
||||||
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||||
if req.Filter == "image" {
|
|
||||||
|
switch req.Filter {
|
||||||
|
case "image":
|
||||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||||
model.JMTaskTypeTextToImage,
|
model.JMTaskTypeTextToImage,
|
||||||
model.JMTaskTypeImageToImage,
|
model.JMTaskTypeImageToImage,
|
||||||
model.JMTaskTypeImageEdit,
|
model.JMTaskTypeImageEdit,
|
||||||
model.JMTaskTypeImageEffects,
|
model.JMTaskTypeImageEffects,
|
||||||
})
|
})
|
||||||
} else if req.Filter == "video" {
|
case "video":
|
||||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||||
model.JMTaskTypeTextToVideo,
|
model.JMTaskTypeTextToVideo,
|
||||||
model.JMTaskTypeImageToVideo,
|
model.JMTaskTypeImageToVideo,
|
||||||
@@ -395,11 +387,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 重新推送到队列
|
// 重新推送到队列
|
||||||
task := map[string]any{
|
if err := h.jimengService.PushTaskToQueue(uint(jobId)); err != nil {
|
||||||
"job_id": jobId,
|
|
||||||
"type": job.Type,
|
|
||||||
}
|
|
||||||
if err := h.jimengService.PushTaskToQueue(task); err != nil {
|
|
||||||
logger.Errorf("push retry task to queue failed: %v", err)
|
logger.Errorf("push retry task to queue failed: %v", err)
|
||||||
resp.ERROR(c, "推送重试任务失败")
|
resp.ERROR(c, "推送重试任务失败")
|
||||||
return
|
return
|
||||||
@@ -408,29 +396,6 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// subUserPower 扣除用户算力
|
|
||||||
func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.PowerLog) {
|
|
||||||
session := h.DB.Session(&gorm.Session{})
|
|
||||||
|
|
||||||
// 更新用户算力
|
|
||||||
if err := session.Model(&model.User{}).Where("id = ?", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error; err != nil {
|
|
||||||
logger.Errorf("update user power failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 记录算力消费日志
|
|
||||||
powerLog.UserId = userId
|
|
||||||
powerLog.Amount = power
|
|
||||||
powerLog.Mark = types.PowerSub
|
|
||||||
powerLog.CreatedAt = time.Now()
|
|
||||||
if err := session.Create(&powerLog).Error; err != nil {
|
|
||||||
logger.Errorf("create power log failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||||
config := h.jimengService.GetConfig()
|
config := h.jimengService.GetConfig()
|
||||||
|
|||||||
@@ -209,10 +209,8 @@ func main() {
|
|||||||
|
|
||||||
// 即梦AI 服务
|
// 即梦AI 服务
|
||||||
fx.Provide(jimeng.NewService),
|
fx.Provide(jimeng.NewService),
|
||||||
fx.Provide(jimeng.NewConsumer),
|
fx.Invoke(func(service *jimeng.Service) {
|
||||||
fx.Invoke(func(consumer *jimeng.Consumer) {
|
service.Start()
|
||||||
//consumer.Start()
|
|
||||||
go consumer.MonitorQueue()
|
|
||||||
}),
|
}),
|
||||||
fx.Provide(service.NewUserService),
|
fx.Provide(service.NewUserService),
|
||||||
fx.Provide(payment.NewAlipayService),
|
fx.Provide(payment.NewAlipayService),
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
|
|||||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
looger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
logger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||||
|
|
||||||
// 解析响应
|
// 解析响应
|
||||||
var result SubmitTaskResponse
|
var result SubmitTaskResponse
|
||||||
@@ -102,7 +102,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
|||||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
looger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
logger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||||
|
|
||||||
// 解析响应
|
// 解析响应
|
||||||
var result QueryTaskResponse
|
var result QueryTaskResponse
|
||||||
@@ -127,7 +127,7 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err
|
|||||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
looger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||||
|
|
||||||
// 解析响应,同步任务直接返回结果
|
// 解析响应,同步任务直接返回结果
|
||||||
var result QueryTaskResponse
|
var result QueryTaskResponse
|
||||||
|
|||||||
@@ -1,177 +0,0 @@
|
|||||||
package jimeng
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"geekai/logger"
|
|
||||||
"geekai/store/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
var jimengLogger = logger.GetLogger()
|
|
||||||
|
|
||||||
// Consumer 即梦任务消费者
|
|
||||||
type Consumer struct {
|
|
||||||
service *Service
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewConsumer 创建即梦任务消费者
|
|
||||||
func NewConsumer(service *Service) *Consumer {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
return &Consumer{
|
|
||||||
service: service,
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start 启动消费者
|
|
||||||
func (c *Consumer) Start() {
|
|
||||||
jimengLogger.Info("Starting Jimeng task consumer...")
|
|
||||||
go c.consume()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop 停止消费者
|
|
||||||
func (c *Consumer) Stop() {
|
|
||||||
jimengLogger.Info("Stopping Jimeng task consumer...")
|
|
||||||
c.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// consume 消费任务
|
|
||||||
func (c *Consumer) consume() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
jimengLogger.Info("Jimeng task consumer stopped")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
c.processTask()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// processTask 处理任务
|
|
||||||
func (c *Consumer) processTask() {
|
|
||||||
// 从队列中获取任务
|
|
||||||
var task map[string]any
|
|
||||||
if err := c.service.taskQueue.LPop(&task); err != nil {
|
|
||||||
// 队列为空,等待1秒后重试
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析任务
|
|
||||||
jobIdFloat, ok := task["job_id"].(float64)
|
|
||||||
if !ok {
|
|
||||||
jimengLogger.Errorf("invalid job_id in task: %v", task)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
jobId := uint(jobIdFloat)
|
|
||||||
|
|
||||||
taskType, ok := task["type"].(string)
|
|
||||||
if !ok {
|
|
||||||
jimengLogger.Errorf("invalid task type in task: %v", task)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jimengLogger.Infof("Processing Jimeng task: job_id=%d, type=%s", jobId, taskType)
|
|
||||||
|
|
||||||
// 处理任务
|
|
||||||
if err := c.service.ProcessTask(jobId); err != nil {
|
|
||||||
jimengLogger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
|
||||||
|
|
||||||
// 任务失败,直接标记为失败状态,不进行重试
|
|
||||||
c.service.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
|
||||||
} else {
|
|
||||||
jimengLogger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TaskQueueStatus 任务队列状态
|
|
||||||
type TaskQueueStatus struct {
|
|
||||||
QueueLength int `json:"queue_length"`
|
|
||||||
ActiveTasks int `json:"active_tasks"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetQueueStatus 获取队列状态
|
|
||||||
func (c *Consumer) GetQueueStatus() (*TaskQueueStatus, error) {
|
|
||||||
// 获取队列长度
|
|
||||||
length, err := c.service.taskQueue.Size()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取活跃任务数(正在处理的任务)
|
|
||||||
activeTasks, err := c.service.GetPendingTaskCount(0) // 0表示所有用户
|
|
||||||
if err != nil {
|
|
||||||
activeTasks = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TaskQueueStatus{
|
|
||||||
QueueLength: int(length),
|
|
||||||
ActiveTasks: int(activeTasks),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MonitorQueue 监控队列状态
|
|
||||||
func (c *Consumer) MonitorQueue() {
|
|
||||||
ticker := time.NewTicker(30 * time.Second) // 每30秒监控一次
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
status, err := c.GetQueueStatus()
|
|
||||||
if err != nil {
|
|
||||||
jimengLogger.Errorf("get queue status failed: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if status.QueueLength > 0 || status.ActiveTasks > 0 {
|
|
||||||
jimengLogger.Infof("Jimeng queue status: queue_length=%d, active_tasks=%d",
|
|
||||||
status.QueueLength, status.ActiveTasks)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
|
||||||
func (c *Consumer) PushTaskToQueue(task map[string]interface{}) error {
|
|
||||||
return c.service.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTaskStats 获取任务统计信息
|
|
||||||
func (c *Consumer) GetTaskStats() (map[string]any, error) {
|
|
||||||
type StatResult struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Count int64 `json:"count"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var stats []StatResult
|
|
||||||
err := c.service.db.Model(&model.JimengJob{}).
|
|
||||||
Select("status, COUNT(*) as count").
|
|
||||||
Group("status").
|
|
||||||
Find(&stats).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result := map[string]any{
|
|
||||||
"total": int64(0),
|
|
||||||
"completed": int64(0),
|
|
||||||
"processing": int64(0),
|
|
||||||
"failed": int64(0),
|
|
||||||
"pending": int64(0),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, stat := range stats {
|
|
||||||
result["total"] = result["total"].(int64) + stat.Count
|
|
||||||
result[stat.Status] = stat.Count
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package jimeng
|
package jimeng
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -19,14 +20,17 @@ import (
|
|||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var looger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
// Service 即梦服务
|
// Service 即梦服务(合并了消费者功能)
|
||||||
type Service struct {
|
type Service struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
client *Client
|
client *Client
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
running bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewService 创建即梦服务
|
// NewService 创建即梦服务
|
||||||
@@ -40,11 +44,68 @@ func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
|
|||||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||||
}
|
}
|
||||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Service{
|
return &Service{
|
||||||
db: db,
|
db: db,
|
||||||
redis: redisCli,
|
redis: redisCli,
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
client: client,
|
client: client,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
running: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动服务(包含消费者)
|
||||||
|
func (s *Service) Start() {
|
||||||
|
if s.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Starting Jimeng service and task consumer...")
|
||||||
|
s.running = true
|
||||||
|
go s.consumeTasks()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止服务
|
||||||
|
func (s *Service) Stop() {
|
||||||
|
if !s.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Stopping Jimeng service and task consumer...")
|
||||||
|
s.running = false
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// consumeTasks 消费任务
|
||||||
|
func (s *Service) consumeTasks() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
logger.Info("Jimeng task consumer stopped")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
s.processNextTask()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNextTask 处理下一个任务
|
||||||
|
func (s *Service) processNextTask() {
|
||||||
|
var jobId uint
|
||||||
|
if err := s.taskQueue.LPop(&jobId); err != nil {
|
||||||
|
// 队列为空,等待1秒后重试
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("Processing Jimeng task: job_id=%d", jobId)
|
||||||
|
|
||||||
|
if err := s.ProcessTask(jobId); err != nil {
|
||||||
|
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||||
|
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||||
|
} else {
|
||||||
|
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,11 +140,7 @@ func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.Jimeng
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 推送到任务队列
|
// 推送到任务队列
|
||||||
task := map[string]any{
|
if err := s.taskQueue.RPush(job.Id); err != nil {
|
||||||
"job_id": job.Id,
|
|
||||||
"type": job.Type,
|
|
||||||
}
|
|
||||||
if err := s.taskQueue.RPush(task); err != nil {
|
|
||||||
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,40 +160,73 @@ func (s *Service) ProcessTask(jobId uint) error {
|
|||||||
return fmt.Errorf("update job status failed: %w", err)
|
return fmt.Errorf("update job status failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据任务类型处理
|
// 构建请求并提交任务
|
||||||
switch job.Type {
|
req, err := s.buildTaskRequest(&job)
|
||||||
case model.JMTaskTypeTextToImage:
|
if err != nil {
|
||||||
return s.processTextToImage(&job)
|
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||||
case model.JMTaskTypeImageToImage:
|
|
||||||
return s.processImageToImage(&job)
|
|
||||||
case model.JMTaskTypeImageEdit:
|
|
||||||
return s.processImageEdit(&job)
|
|
||||||
case model.JMTaskTypeImageEffects:
|
|
||||||
return s.processImageEffects(&job)
|
|
||||||
case model.JMTaskTypeTextToVideo:
|
|
||||||
return s.processTextToVideo(&job)
|
|
||||||
case model.JMTaskTypeImageToVideo:
|
|
||||||
return s.processImageToVideo(&job)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported task type: %s", job.Type)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提交异步任务
|
||||||
|
resp, err := s.client.SubmitTask(req)
|
||||||
|
if err != nil {
|
||||||
|
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Code != 10000 {
|
||||||
|
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新任务ID和原始数据
|
||||||
|
rawData, _ := json.Marshal(resp)
|
||||||
|
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||||
|
"task_id": resp.Data.TaskId,
|
||||||
|
"raw_data": string(rawData),
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
}).Error; err != nil {
|
||||||
|
logger.Errorf("update jimeng job task_id failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始轮询任务状态
|
||||||
|
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processTextToImage 处理文生图任务
|
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||||
func (s *Service) processTextToImage(job *model.JimengJob) error {
|
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||||
// 解析任务参数
|
// 解析任务参数
|
||||||
var params map[string]any
|
var params map[string]any
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建请求
|
// 构建基础请求
|
||||||
req := &SubmitTaskRequest{
|
req := &SubmitTaskRequest{
|
||||||
ReqKey: job.ReqKey,
|
ReqKey: job.ReqKey,
|
||||||
Prompt: job.Prompt,
|
Prompt: job.Prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置参数
|
// 根据任务类型设置特定参数
|
||||||
|
switch job.Type {
|
||||||
|
case model.JMTaskTypeTextToImage:
|
||||||
|
s.setTextToImageParams(req, params)
|
||||||
|
case model.JMTaskTypeImageToImage:
|
||||||
|
s.setImageToImageParams(req, params)
|
||||||
|
case model.JMTaskTypeImageEdit:
|
||||||
|
s.setImageEditParams(req, params)
|
||||||
|
case model.JMTaskTypeImageEffects:
|
||||||
|
s.setImageEffectsParams(req, params)
|
||||||
|
case model.JMTaskTypeTextToVideo:
|
||||||
|
s.setTextToVideoParams(req, params)
|
||||||
|
case model.JMTaskTypeImageToVideo:
|
||||||
|
s.setImageToVideoParams(req, params)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setTextToImageParams 设置文生图参数
|
||||||
|
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
if seed, ok := params["seed"]; ok {
|
if seed, ok := params["seed"]; ok {
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||||
req.Seed = seedVal
|
req.Seed = seedVal
|
||||||
@@ -162,51 +252,13 @@ func (s *Service) processTextToImage(job *model.JimengJob) error {
|
|||||||
req.UsePreLLM = usePreLlmVal
|
req.UsePreLLM = usePreLlmVal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processImageToImage 处理图生图任务
|
// setImageToImageParams 设置图生图参数
|
||||||
func (s *Service) processImageToImage(job *model.JimengJob) error {
|
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
// 解析任务参数
|
|
||||||
var params map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req := &SubmitTaskRequest{
|
|
||||||
ReqKey: job.ReqKey,
|
|
||||||
Prompt: job.Prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置图像输入
|
|
||||||
if imageInput, ok := params["image_input"].(string); ok {
|
if imageInput, ok := params["image_input"].(string); ok {
|
||||||
req.ImageInput = imageInput
|
req.ImageInput = imageInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置其他参数
|
|
||||||
if gpen, ok := params["gpen"]; ok {
|
if gpen, ok := params["gpen"]; ok {
|
||||||
if gpenVal, ok := gpen.(float64); ok {
|
if gpenVal, ok := gpen.(float64); ok {
|
||||||
req.Gpen = gpenVal
|
req.Gpen = gpenVal
|
||||||
@@ -225,61 +277,11 @@ func (s *Service) processImageToImage(job *model.JimengJob) error {
|
|||||||
if genMode, ok := params["gen_mode"].(string); ok {
|
if genMode, ok := params["gen_mode"].(string); ok {
|
||||||
req.GenMode = genMode
|
req.GenMode = genMode
|
||||||
}
|
}
|
||||||
if width, ok := params["width"]; ok {
|
s.setCommonParams(req, params) // 复用通用参数
|
||||||
if widthVal, ok := width.(float64); ok {
|
|
||||||
req.Width = int(widthVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if height, ok := params["height"]; ok {
|
|
||||||
if heightVal, ok := height.(float64); ok {
|
|
||||||
req.Height = int(heightVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if seed, ok := params["seed"]; ok {
|
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
|
||||||
req.Seed = seedVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processImageEdit 处理图像编辑任务
|
// setImageEditParams 设置图像编辑参数
|
||||||
func (s *Service) processImageEdit(job *model.JimengJob) error {
|
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
// 解析任务参数
|
|
||||||
var params map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req := &SubmitTaskRequest{
|
|
||||||
ReqKey: job.ReqKey,
|
|
||||||
Prompt: job.Prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置图像输入
|
|
||||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||||
for _, url := range imageUrls {
|
for _, url := range imageUrls {
|
||||||
if urlStr, ok := url.(string); ok {
|
if urlStr, ok := url.(string); ok {
|
||||||
@@ -294,57 +296,16 @@ func (s *Service) processImageEdit(job *model.JimengJob) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置其他参数
|
|
||||||
if seed, ok := params["seed"]; ok {
|
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
|
||||||
req.Seed = seedVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if scale, ok := params["scale"]; ok {
|
if scale, ok := params["scale"]; ok {
|
||||||
if scaleVal, ok := scale.(float64); ok {
|
if scaleVal, ok := scale.(float64); ok {
|
||||||
req.Scale = scaleVal
|
req.Scale = scaleVal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
s.setCommonParams(req, params)
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processImageEffects 处理图像特效任务
|
// setImageEffectsParams 设置图像特效参数
|
||||||
func (s *Service) processImageEffects(job *model.JimengJob) error {
|
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
// 解析任务参数
|
|
||||||
var params map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req := &SubmitTaskRequest{
|
|
||||||
ReqKey: job.ReqKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置图像输入
|
|
||||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||||
req.ImageInput1 = imageInput1
|
req.ImageInput1 = imageInput1
|
||||||
}
|
}
|
||||||
@@ -361,141 +322,41 @@ func (s *Service) processImageEffects(job *model.JimengJob) error {
|
|||||||
req.Height = int(heightVal)
|
req.Height = int(heightVal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processTextToVideo 处理文生视频任务
|
// setTextToVideoParams 设置文生视频参数
|
||||||
func (s *Service) processTextToVideo(job *model.JimengJob) error {
|
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
// 解析任务参数
|
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||||
var params map[string]any
|
req.AspectRatio = aspectRatio
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
|
||||||
}
|
}
|
||||||
|
s.setCommonParams(req, params)
|
||||||
|
}
|
||||||
|
|
||||||
// 构建请求
|
// setImageToVideoParams 设置图生视频参数
|
||||||
req := &SubmitTaskRequest{
|
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
ReqKey: job.ReqKey,
|
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||||
Prompt: job.Prompt,
|
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||||
|
req.AspectRatio = aspectRatio
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 设置参数
|
// setCommonParams 设置通用参数(seed, width, height等)
|
||||||
|
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||||
if seed, ok := params["seed"]; ok {
|
if seed, ok := params["seed"]; ok {
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||||
req.Seed = seedVal
|
req.Seed = seedVal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
if width, ok := params["width"]; ok {
|
||||||
req.AspectRatio = aspectRatio
|
if widthVal, ok := width.(float64); ok {
|
||||||
}
|
req.Width = int(widthVal)
|
||||||
|
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// processImageToVideo 处理图生视频任务
|
|
||||||
func (s *Service) processImageToVideo(job *model.JimengJob) error {
|
|
||||||
// 解析任务参数
|
|
||||||
var params map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req := &SubmitTaskRequest{
|
|
||||||
ReqKey: job.ReqKey,
|
|
||||||
Prompt: job.Prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置图像输入
|
|
||||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
|
||||||
for _, url := range imageUrls {
|
|
||||||
if urlStr, ok := url.(string); ok {
|
|
||||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
if height, ok := params["height"]; ok {
|
||||||
for _, data := range binaryData {
|
if heightVal, ok := height.(float64); ok {
|
||||||
if dataStr, ok := data.(string); ok {
|
req.Height = int(heightVal)
|
||||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置其他参数
|
|
||||||
if seed, ok := params["seed"]; ok {
|
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
|
||||||
req.Seed = seedVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
|
||||||
req.AspectRatio = aspectRatio
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提交异步任务
|
|
||||||
resp, err := s.client.SubmitTask(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务ID和原始数据
|
|
||||||
rawData, _ := json.Marshal(resp)
|
|
||||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
|
||||||
"task_id": resp.Data.TaskId,
|
|
||||||
"raw_data": string(rawData),
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error; err != nil {
|
|
||||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 开始轮询任务状态
|
|
||||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// pollTaskStatus 轮询任务状态
|
// pollTaskStatus 轮询任务状态
|
||||||
@@ -514,7 +375,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
looger.Errorf("query jimeng task status failed: %v", err)
|
logger.Errorf("query jimeng task status failed: %v", err)
|
||||||
retryCount++
|
retryCount++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -537,7 +398,6 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
|||||||
// 任务完成,更新结果
|
// 任务完成,更新结果
|
||||||
updates := map[string]any{
|
updates := map[string]any{
|
||||||
"status": model.JMTaskStatusSuccess,
|
"status": model.JMTaskStatusSuccess,
|
||||||
"progress": 100,
|
|
||||||
"updated_at": time.Now(),
|
"updated_at": time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,18 +413,18 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
|||||||
|
|
||||||
case model.JMTaskStatusInQueue:
|
case model.JMTaskStatusInQueue:
|
||||||
// 任务在队列中
|
// 任务在队列中
|
||||||
s.UpdateJobProgress(jobId, 10)
|
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
|
||||||
|
|
||||||
case model.JMTaskStatusGenerating:
|
case model.JMTaskStatusGenerating:
|
||||||
// 任务处理中
|
// 任务处理中
|
||||||
s.UpdateJobProgress(jobId, 50)
|
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
|
||||||
|
|
||||||
case model.JMTaskStatusNotFound:
|
case model.JMTaskStatusNotFound:
|
||||||
// 任务未找到或已过期
|
// 任务未找到或已过期
|
||||||
return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status))
|
return s.handleTaskError(jobId, resp.Message)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
looger.Warnf("unknown task status: %s", resp.Data.Status)
|
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
retryCount++
|
retryCount++
|
||||||
@@ -586,20 +446,49 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg
|
|||||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateJobProgress 更新任务进度
|
|
||||||
func (s *Service) UpdateJobProgress(jobId uint, progress int) error {
|
|
||||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(map[string]any{
|
|
||||||
"progress": progress,
|
|
||||||
"updated_at": time.Now(),
|
|
||||||
}).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleTaskError 处理任务错误
|
// handleTaskError 处理任务错误
|
||||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||||
looger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||||
|
func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||||
|
return s.taskQueue.RPush(jobId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskStats 获取任务统计信息
|
||||||
|
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||||
|
type StatResult struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Count int64 `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var stats []StatResult
|
||||||
|
err := s.db.Model(&model.JimengJob{}).
|
||||||
|
Select("status, COUNT(*) as count").
|
||||||
|
Group("status").
|
||||||
|
Find(&stats).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{
|
||||||
|
"total": int64(0),
|
||||||
|
"completed": int64(0),
|
||||||
|
"processing": int64(0),
|
||||||
|
"failed": int64(0),
|
||||||
|
"pending": int64(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stat := range stats {
|
||||||
|
result["total"] = result["total"].(int64) + stat.Count
|
||||||
|
result[stat.Status] = stat.Count
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetJob 获取任务
|
// GetJob 获取任务
|
||||||
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||||
var job model.JimengJob
|
var job model.JimengJob
|
||||||
@@ -609,54 +498,16 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
|||||||
return &job, nil
|
return &job, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJobByPage 分页获取任务列表
|
|
||||||
func (s *Service) GetJobByPage(userId uint, page, pageSize int) ([]*model.JimengJob, int64, error) {
|
|
||||||
var jobs []*model.JimengJob
|
|
||||||
var total int64
|
|
||||||
|
|
||||||
query := s.db.Model(&model.JimengJob{})
|
|
||||||
if userId > 0 {
|
|
||||||
query = query.Where("user_id = ?", userId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统计总数
|
|
||||||
if err := query.Count(&total).Error; err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分页查询
|
|
||||||
offset := (page - 1) * pageSize
|
|
||||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return jobs, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingTaskCount 获取用户未完成任务数量
|
|
||||||
func (s *Service) GetPendingTaskCount(userId uint) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
err := s.db.Model(&model.JimengJob{}).Where("user_id = ? AND status IN (?)", userId,
|
|
||||||
[]model.JMTaskStatus{model.JMTaskStatusInQueue, model.JMTaskStatusGenerating}).Count(&count).Error
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteJob 删除任务
|
// DeleteJob 删除任务
|
||||||
func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
||||||
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
|
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushTaskToQueue 推送任务到队列
|
|
||||||
func (s *Service) PushTaskToQueue(task map[string]any) error {
|
|
||||||
return s.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
// testConnection 测试即梦AI连接
|
// testConnection 测试即梦AI连接
|
||||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||||
testClient := NewClient(accessKey, secretKey)
|
testClient := NewClient(accessKey, secretKey)
|
||||||
|
|
||||||
// 使用一个简单的查询任务来测试连接
|
// 使用一个简单的查询任务来测试连接
|
||||||
// 这里使用一个不存在的任务ID来测试API连接是否正常
|
|
||||||
testReq := &QueryTaskRequest{
|
testReq := &QueryTaskRequest{
|
||||||
ReqKey: "test_connection",
|
ReqKey: "test_connection",
|
||||||
TaskId: "test_task_id_12345",
|
TaskId: "test_task_id_12345",
|
||||||
|
|||||||
Reference in New Issue
Block a user