文生视频和图生视频功能完成

This commit is contained in:
GeekMaster
2025-07-23 19:11:30 +08:00
parent 54fe49de5d
commit a3f6a641aa
20 changed files with 640 additions and 610 deletions

View File

@@ -293,7 +293,7 @@ func (s *Service) DownloadImages() {
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, ".png", false)
if err != nil {
return "", err
}

View File

@@ -11,6 +11,7 @@ import (
"gorm.io/gorm"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
@@ -31,10 +32,11 @@ type Service struct {
ctx context.Context
cancel context.CancelFunc
running bool
uploader *oss.UploaderManager
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
@@ -54,6 +56,7 @@ func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
ctx: ctx,
cancel: cancel,
running: false,
uploader: uploader,
}
}
@@ -65,6 +68,7 @@ func (s *Service) Start() {
logger.Info("Starting Jimeng service and task consumer...")
s.running = true
go s.consumeTasks()
go s.pollTaskStatus()
}
// Stop 停止服务
@@ -166,6 +170,8 @@ func (s *Service) ProcessTask(jobId uint) error {
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
}
logger.Infof("提交即梦任务: %+v", req)
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
@@ -186,8 +192,7 @@ func (s *Service) ProcessTask(jobId uint) error {
logger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
return nil
}
// buildTaskRequest 构建任务请求(统一的参数解析)
@@ -360,78 +365,100 @@ func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any)
}
// pollTaskStatus 轮询任务状态
func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
maxRetries := 60 // 最大重试次数60次 * 5秒 = 5分钟
retryCount := 0
func (s *Service) pollTaskStatus() {
for retryCount < maxRetries {
time.Sleep(5 * time.Second) // 等待5秒
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: reqKey,
TaskId: taskId,
ReqJson: `{"return_url":true}`,
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
retryCount++
for {
var jobs []model.JimengJob
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
if len(jobs) == 0 {
logger.Debugf("no jimeng task to poll, sleep 10s")
time.Sleep(10 * time.Second)
continue
}
// 更新原始数据
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData))
for _, job := range jobs {
// 任务超时处理
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
s.handleTaskError(job.Id, "task timeout")
continue
}
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: job.ReqKey,
TaskId: job.TaskId,
ReqJson: `{"return_url":true}`,
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
continue
}
// 更新原始数据
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
if resp.Code != 10000 {
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
continue
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
continue
}
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
// 设置结果URL
if len(resp.Data.ImageUrls) > 0 {
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
if err != nil {
logger.Errorf("upload image failed: %v", err)
imgUrl = resp.Data.ImageUrls[0]
}
updates["img_url"] = imgUrl
}
if resp.Data.VideoUrl != "" {
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
if err != nil {
logger.Errorf("upload video failed: %v", err)
videoUrl = resp.Data.VideoUrl
}
updates["video_url"] = videoUrl
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
// 任务未找到
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
// 任务过期
s.handleTaskError(job.Id, "task expired")
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
if resp.Code != 10000 {
return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message))
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
return s.handleTaskError(jobId, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
}
time.Sleep(5 * time.Second)
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
// 设置结果URL
if len(resp.Data.ImageUrls) > 0 {
updates["img_url"] = resp.Data.ImageUrls[0]
}
if resp.Data.VideoUrl != "" {
updates["video_url"] = resp.Data.VideoUrl
}
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
case model.JMTaskStatusInQueue:
// 任务在队列中
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
// 任务未找到或已过期
return s.handleTaskError(jobId, resp.Message)
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
retryCount++
}
// 超时处理
return s.handleTaskError(jobId, "task timeout")
}
// UpdateJobStatus 更新任务状态
@@ -498,11 +525,6 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
return &job, nil
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(jobId uint, userId uint) error {
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
}
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)

View File

@@ -191,7 +191,7 @@ func (s *Service) DownloadImages() {
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, ".png", proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)

View File

@@ -84,7 +84,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
@@ -99,8 +99,10 @@ func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := utils.GetImgExt(parse.Path)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
if ext == "" {
ext = filepath.Ext(parse.Path)
}
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil {

View File

@@ -12,11 +12,12 @@ import (
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/gin-gonic/gin"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type LocalStorage struct {
@@ -37,7 +38,7 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
return File{}, fmt.Errorf("error with get form: %v", err)
}
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, "")
if err != nil {
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
}
@@ -57,13 +58,13 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
func (s LocalStorage) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
parse, err := url.Parse(fileURL)
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
filename := filepath.Base(parse.Path)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, ext)
if err != nil {
return "", fmt.Errorf("error with generate image dir: %v", err)
}
@@ -85,7 +86,7 @@ func (s LocalStorage) PutBase64(base64Img string) (string, error) {
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
filePath, _ := utils.GenUploadPath(s.config.BasePath, "", ".png")
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)

View File

@@ -44,7 +44,7 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
}
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
@@ -59,8 +59,10 @@ func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := filepath.Ext(parse.Path)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
if ext == "" {
ext = filepath.Ext(parse.Path)
}
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -86,7 +88,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
defer fileReader.Close()
fileExt := utils.GetImgExt(file.Filename)
fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Body-Type"),

View File

@@ -93,7 +93,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
@@ -108,8 +108,10 @@ func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
if err != nil {
return "", fmt.Errorf("error with parse image URL: %v", err)
}
fileExt := utils.GetImgExt(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
if ext == "" {
ext = filepath.Ext(parse.Path)
}
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据

View File

@@ -23,7 +23,7 @@ type File struct {
}
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)
PutUrlFile(url string, useProxy bool) (string, error)
PutUrlFile(url string, ext string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@@ -272,14 +272,14 @@ func (s *Service) DownloadFiles() {
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, ".png", true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, ".mp3", true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue

View File

@@ -164,7 +164,7 @@ func (s *Service) DownloadFiles() {
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, ".mp4", true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
@@ -174,7 +174,7 @@ func (s *Service) DownloadFiles() {
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, ".mp4", true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue