文生视频和图生视频功能完成

This commit is contained in:
GeekMaster
2025-07-23 19:11:30 +08:00
parent 54fe49de5d
commit a3f6a641aa
20 changed files with 640 additions and 610 deletions

View File

@@ -11,6 +11,7 @@ import (
"gorm.io/gorm"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
@@ -31,10 +32,11 @@ type Service struct {
ctx context.Context
cancel context.CancelFunc
running bool
uploader *oss.UploaderManager
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
@@ -54,6 +56,7 @@ func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
ctx: ctx,
cancel: cancel,
running: false,
uploader: uploader,
}
}
@@ -65,6 +68,7 @@ func (s *Service) Start() {
logger.Info("Starting Jimeng service and task consumer...")
s.running = true
go s.consumeTasks()
go s.pollTaskStatus()
}
// Stop 停止服务
@@ -166,6 +170,8 @@ func (s *Service) ProcessTask(jobId uint) error {
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
}
logger.Infof("提交即梦任务: %+v", req)
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
@@ -186,8 +192,7 @@ func (s *Service) ProcessTask(jobId uint) error {
logger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
return nil
}
// buildTaskRequest 构建任务请求(统一的参数解析)
@@ -360,78 +365,100 @@ func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any)
}
// pollTaskStatus 轮询任务状态
func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
maxRetries := 60 // 最大重试次数60次 * 5秒 = 5分钟
retryCount := 0
func (s *Service) pollTaskStatus() {
for retryCount < maxRetries {
time.Sleep(5 * time.Second) // 等待5秒
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: reqKey,
TaskId: taskId,
ReqJson: `{"return_url":true}`,
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
retryCount++
for {
var jobs []model.JimengJob
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
if len(jobs) == 0 {
logger.Debugf("no jimeng task to poll, sleep 10s")
time.Sleep(10 * time.Second)
continue
}
// 更新原始数据
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData))
for _, job := range jobs {
// 任务超时处理
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
s.handleTaskError(job.Id, "task timeout")
continue
}
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: job.ReqKey,
TaskId: job.TaskId,
ReqJson: `{"return_url":true}`,
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
continue
}
// 更新原始数据
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
if resp.Code != 10000 {
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
continue
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
continue
}
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
// 设置结果URL
if len(resp.Data.ImageUrls) > 0 {
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
if err != nil {
logger.Errorf("upload image failed: %v", err)
imgUrl = resp.Data.ImageUrls[0]
}
updates["img_url"] = imgUrl
}
if resp.Data.VideoUrl != "" {
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
if err != nil {
logger.Errorf("upload video failed: %v", err)
videoUrl = resp.Data.VideoUrl
}
updates["video_url"] = videoUrl
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
// 任务未找到
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
// 任务过期
s.handleTaskError(job.Id, "task expired")
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
if resp.Code != 10000 {
return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message))
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
return s.handleTaskError(jobId, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
}
time.Sleep(5 * time.Second)
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
// 设置结果URL
if len(resp.Data.ImageUrls) > 0 {
updates["img_url"] = resp.Data.ImageUrls[0]
}
if resp.Data.VideoUrl != "" {
updates["video_url"] = resp.Data.VideoUrl
}
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
case model.JMTaskStatusInQueue:
// 任务在队列中
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
// 任务未找到或已过期
return s.handleTaskError(jobId, resp.Message)
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
retryCount++
}
// 超时处理
return s.handleTaskError(jobId, "task timeout")
}
// UpdateJobStatus 更新任务状态
@@ -498,11 +525,6 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
return &job, nil
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(jobId uint, userId uint) error {
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
}
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)