From 2c6eee7fc1e52b5f683ca82b2375120c463eaf13 Mon Sep 17 00:00:00 2001 From: GeekMaster Date: Fri, 12 Sep 2025 18:58:52 +0800 Subject: [PATCH] acommpelish jimeng AI refactor for PC --- api/core/types/jimeng.go | 16 +- api/go.mod | 2 + api/go.sum | 6 + api/handler/jimeng_handler.go | 74 +- api/service/jimeng/client.go | 54 +- api/service/jimeng/service.go | 243 ++---- api/service/jimeng/types.go | 33 +- api/utils/media_duration.go | 817 ++++++++++++++++++++ web/src/components/ParamBuilder.vue | 2 +- web/src/components/ui/Alert.vue | 2 +- web/src/store/data/jimeng_data.js | 16 + web/src/store/jimeng.js | 19 +- web/src/views/admin/jimeng/JimengConfig.vue | 31 +- 13 files changed, 1049 insertions(+), 266 deletions(-) create mode 100644 api/utils/media_duration.go diff --git a/api/core/types/jimeng.go b/api/core/types/jimeng.go index b186ceab..29fd87f8 100644 --- a/api/core/types/jimeng.go +++ b/api/core/types/jimeng.go @@ -2,9 +2,13 @@ package types // JimengConfig 即梦AI配置 type JimengConfig struct { - AccessKey string `json:"access_key"` - SecretKey string `json:"secret_key"` - Power JimengPower `json:"power"` + // 即梦AI的AccessKey和SecretKey + AccessKey string `json:"access_key"` + SecretKey string `json:"secret_key"` + // 火山引擎大模型专用的验证方式 + ApiKey string `json:"api_key"` + // 算力配置 + Power JimengPower `json:"power"` } // JimengPower 即梦AI算力配置 @@ -40,7 +44,9 @@ const ( // JimengTaskRequest 即梦AI任务请求 type JimengTaskRequest struct { - ReqKey string `json:"req_key"` // 请求Key + TaskType JMTaskType `json:"type"` // 任务类型 + ReqKey string `json:"req_key"` // 请求Key + Power int `json:"power"` // 消耗算力 // 公共参数 Prompt string `json:"prompt,omitempty"` ImageUrls []string `json:"image_urls,omitempty"` @@ -50,7 +56,7 @@ type JimengTaskRequest struct { UsePreLLM bool `json:"use_pre_llm,omitempty"` // 视频生成参数 - Duration string `json:"duration,omitempty"` // 视频时长 + Duration int `json:"duration,omitempty"` // 视频时长,单位:秒 TemplateId string `json:"template_id,omitempty"` // 运镜模板ID AspectRatio string `json:"aspect_ratio,omitempty"` CameraStrength string `json:"camera_strength,omitempty"` // 运镜强度 diff --git a/api/go.mod b/api/go.mod index 62da5307..8ae01b21 100644 --- a/api/go.mod +++ b/api/go.mod @@ -33,6 +33,7 @@ require ( github.com/shirou/gopsutil v3.21.11+incompatible github.com/shopspring/decimal v1.3.1 github.com/syndtr/goleveldb v1.0.0 + github.com/volcengine/volcengine-go-sdk v1.1.34 golang.org/x/image v0.15.0 ) @@ -50,6 +51,7 @@ require ( github.com/tklauser/numcpus v0.7.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/mock v0.4.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) require ( diff --git a/api/go.sum b/api/go.sum index 702add84..0a8b0d75 100644 --- a/api/go.sum +++ b/api/go.sum @@ -100,6 +100,7 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -110,6 +111,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA= @@ -259,6 +261,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8= github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU= +github.com/volcengine/volcengine-go-sdk v1.1.34 h1:ha90JycCCTJNCse0UDziBgBsuX98ITOrkwYlDWcm7NI= +github.com/volcengine/volcengine-go-sdk v1.1.34/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= @@ -390,6 +394,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/api/handler/jimeng_handler.go b/api/handler/jimeng_handler.go index f9c7b0f0..f3dbde7e 100644 --- a/api/handler/jimeng_handler.go +++ b/api/handler/jimeng_handler.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "fmt" "geekai/core" "geekai/core/middleware" @@ -94,32 +95,29 @@ func (h *JimengHandler) CreateTask(c *gin.Context) { } // 获取算力消耗 + powerCost, err := h.getTaskPower(req) + if err != nil { + resp.ERROR(c, "计算任务消耗积分失败") + return + } - // if user.Power < powerCost { - // resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost)) - // return - // } + if user.Power < powerCost { + resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost)) + return + } - // taskReq := &jimeng.CreateTaskRequest{ - // Type: taskType, - // Prompt: req.Prompt, - // Params: params, - // ReqKey: reqKey, - // Power: powerCost, - // } + job, err := h.jimengService.CreateTask(user.Id, &req) + if err != nil { + logger.Errorf("create jimeng task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } - // job, err := h.jimengService.CreateTask(user.Id, taskReq) - // if err != nil { - // logger.Errorf("create jimeng task failed: %v", err) - // resp.ERROR(c, "创建任务失败") - // return - // } - - // h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{ - // Type: types.PowerConsume, - // Model: "jimeng", - // Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id), - // }) + h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{ + Type: types.PowerConsume, + Model: job.ReqKey, + Remark: fmt.Sprintf("%s,任务ID:%d", req.ReqKey, job.Id), + }) resp.SUCCESS(c) } @@ -224,7 +222,7 @@ func (h *JimengHandler) Remove(c *gin.Context) { if job.Status != types.JMTaskStatusFailed { err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{ Type: types.PowerRefund, - Model: "jimeng", + Model: job.ReqKey, Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power), }) if err != nil { @@ -285,20 +283,24 @@ func (h *JimengHandler) Retry(c *gin.Context) { } // getPowerFromConfig 从配置中获取指定类型的算力消耗 -func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int { +func (h *JimengHandler) getTaskPower(req types.JimengTaskRequest) (int, error) { config := h.App.SysConfig.Jimeng - - switch taskType { + switch req.TaskType { case types.JMTaskTypeImage: - return config.Power.Image + return config.Power.Image, nil case types.JMTaskTypeVideo: - return config.Power.Video + if req.Duration == 0 { + return 0, errors.New("视频时长不能为0") + } + return config.Power.Video * req.Duration, nil case types.JMTaskTypeVirtualHuman: - return config.Power.VirtualHuman + // TODO 计算音频时长 + return config.Power.VirtualHuman, nil case types.JMTaskTypeActionTransfer: - return config.Power.ActionTransfer + // TODO 计算视频时长 + return config.Power.ActionTransfer, nil default: - return 10 + return 0, errors.New("任务类型不支持") } } @@ -306,9 +308,9 @@ func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int { func (h *JimengHandler) GetPowerConfig(c *gin.Context) { config := h.App.SysConfig.Jimeng resp.SUCCESS(c, gin.H{ - "image": config.Power.Image, - "video": config.Power.Video, - "image_edit": config.Power.VirtualHuman, - "image_effects": config.Power.ActionTransfer, + "image": config.Power.Image, + "video": config.Power.Video, + "virtual_human": config.Power.VirtualHuman, + "action_transfer": config.Power.ActionTransfer, }) } diff --git a/api/service/jimeng/client.go b/api/service/jimeng/client.go index 87864488..3034c4a5 100644 --- a/api/service/jimeng/client.go +++ b/api/service/jimeng/client.go @@ -1,6 +1,7 @@ package jimeng import ( + "context" "encoding/json" "fmt" "geekai/core/types" @@ -10,6 +11,9 @@ import ( "github.com/volcengine/volc-sdk-golang/base" "github.com/volcengine/volc-sdk-golang/service/visual" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "github.com/volcengine/volcengine-go-sdk/volcengine" ) // Client 即梦API客户端 @@ -94,13 +98,19 @@ func (c *Client) testConnection() error { } // SubmitTask 提交异步任务 -func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) { +func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) { // 直接将请求转为map[string]interface{} reqBodyBytes, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("marshal request failed: %w", err) } + // 单独处理图片特效任务 + if req["req_key"] == ImageEffectReqKey { + req["image_input1"] = req["image_urls"].([]any)[0] + delete(req, "image_urls") + } + // 直接使用序列化后的字节 jsonBody := reqBodyBytes @@ -146,27 +156,29 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) { return &result, nil } -// SubmitSyncTask 提交同步任务(仅用于文生图) -func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) { - // 序列化请求 - jsonBody, err := json.Marshal(req) +// SubmitSyncImageTask 提交同步生图任务 +func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) { + // 配置火山引擎访问密钥,目前只支持API Key验证 + client := arkruntime.NewClientWithApiKey(c.config.ApiKey) + // 构造生图请求 + sequentialImageGeneration := model.SequentialImageGeneration("disabled") + generateReq := model.GenerateImagesRequest{ + Model: req.ReqKey, // 模型名称 + Prompt: req.Prompt, // 提示词 + Size: volcengine.String(req.Size), // 图片尺寸 + SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成 + ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL + Watermark: volcengine.Bool(false), // 不添加水印 + OptimizePrompt: volcengine.Bool(true), // 优化提示词 + } + if len(req.ImageUrls) > 0 { + generateReq.Image = req.ImageUrls + } + // 调用生图 API + resp, err := client.GenerateImages(context.Background(), generateReq) if err != nil { - return nil, fmt.Errorf("marshal request failed: %w", err) + return nil, err } - // 调用SDK的JSON方法 - respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody)) - if err != nil { - return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err) - } - - logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody)) - - // 解析响应,同步任务直接返回结果 - var result QueryTaskResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("unmarshal response failed: %w", err) - } - - return &result, nil + return &resp, nil } diff --git a/api/service/jimeng/service.go b/api/service/jimeng/service.go index 9ddc41a6..ad9be8bb 100644 --- a/api/service/jimeng/service.go +++ b/api/service/jimeng/service.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - "strconv" + "strings" "time" "gorm.io/gorm" @@ -103,24 +103,18 @@ func (s *Service) processNextTask() { } // CreateTask 创建任务 -func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) { +func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) { // 生成任务ID taskId := utils.RandString(20) - // 序列化任务参数 - paramsJson, err := json.Marshal(req.Params) - if err != nil { - return nil, fmt.Errorf("marshal task params failed: %w", err) - } - // 创建任务记录 job := &model.JimengJob{ UserId: userId, TaskId: taskId, - Type: req.Type, + Type: req.TaskType, ReqKey: req.ReqKey, Prompt: req.Prompt, - Params: string(paramsJson), + Params: utils.JsonEncode(req), Status: types.JMTaskStatusInQueue, Power: req.Power, CreatedAt: time.Now(), @@ -153,21 +147,61 @@ func (s *Service) ProcessTask(jobId uint) error { return fmt.Errorf("update job status failed: %w", err) } + // 解析任务参数 + var req types.JimengTaskRequest + err := utils.JsonDecode(job.Params, &req) + if err != nil { + return fmt.Errorf("parse task params failed: %w", err) + } + // 构建请求并提交任务 - req, err := s.buildTaskRequest(&job) + params, err := s.buildTaskRequest(&req) if err != nil { return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err)) } - logger.Infof("提交即梦任务: %+v", req) + logger.Debugf("提交即梦任务: %+v", params) - // 提交异步任务 - resp, err := s.client.SubmitTask(req) + // 同步任务 ,后台执行 + if req.ReqKey == DoubaoSeedream40ReqKey { + go func() { + resp, err := s.client.SubmitSyncImageTask(req) + if err != nil { + _ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + return + } + logger.Infof("同步任务提交成功: %+v", resp) + // 更新原始数据 + rawData, _ := json.Marshal(resp) + updates := map[string]any{ + "raw_data": string(rawData), + } + if resp.Error != nil { + updates["status"] = types.JMTaskStatusFailed + updates["err_msg"] = resp.Error.Message + s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates) + return + } + + // 更新任务状态 + updates["status"] = types.JMTaskStatusSuccess + // 下载图片 + imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false) + if err == nil { + updates["img_url"] = imgUrl + } + s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates) + }() + return nil + } + + // 异步任务 ,前台执行 + resp, err := s.client.SubmitTask(params) if err != nil { return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) } - if resp.Code != 10000 { + if resp.Code != CodeSuccess { return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) } @@ -185,168 +219,36 @@ func (s *Service) ProcessTask(jobId uint) error { } // buildTaskRequest 构建任务请求(统一的参数解析) -func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) { - // 解析任务参数 +func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) { var params map[string]any - if err := json.Unmarshal([]byte(job.Params), ¶ms); err != nil { + err := utils.JsonDecode(utils.JsonEncode(req), ¶ms) + if err != nil { return nil, fmt.Errorf("parse task params failed: %w", err) } - - // 构建基础请求 - req := &SubmitTaskRequest{ - ReqKey: job.ReqKey, - Prompt: job.Prompt, - } - - // 根据任务类型设置特定参数 - switch job.Type { - case types.JMTaskTypeImage: - s.setTextToImageParams(req, params) - case types.JMTaskTypeVideo: - s.setImageToImageParams(req, params) - case types.JMTaskTypeVirtualHuman: - s.setImageEditParams(req, params) - case types.JMTaskTypeActionTransfer: - s.setImageEffectsParams(req, params) - default: - return nil, fmt.Errorf("unsupported task type: %s", job.Type) - } - - return req, nil -} - -// setTextToImageParams 设置文生图参数 -func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) { - if seed, ok := params["seed"]; ok { - if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { - req.Seed = seedVal - } - } - if scale, ok := params["scale"]; ok { - if scaleVal, ok := scale.(float64); ok { - req.Scale = scaleVal - } - } - if width, ok := params["width"]; ok { - if widthVal, ok := width.(float64); ok { - req.Width = int(widthVal) - } - } - if height, ok := params["height"]; ok { - if heightVal, ok := height.(float64); ok { - req.Height = int(heightVal) - } - } - if usePreLlm, ok := params["use_pre_llm"]; ok { - if usePreLlmVal, ok := usePreLlm.(bool); ok { - req.UsePreLLM = usePreLlmVal - } - } -} - -// setImageToImageParams 设置图生图参数 -func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) { - if imageInput, ok := params["image_input"].(string); ok { - req.ImageInput = imageInput - } - if gpen, ok := params["gpen"]; ok { - if gpenVal, ok := gpen.(float64); ok { - req.Gpen = gpenVal - } - } - if skin, ok := params["skin"]; ok { - if skinVal, ok := skin.(float64); ok { - req.Skin = skinVal - } - } - if skinUnifi, ok := params["skin_unifi"]; ok { - if skinUnifiVal, ok := skinUnifi.(float64); ok { - req.SkinUnifi = skinUnifiVal - } - } - if genMode, ok := params["gen_mode"].(string); ok { - req.GenMode = genMode - } - s.setCommonParams(req, params) // 复用通用参数 -} - -// setImageEditParams 设置图像编辑参数 -func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) { - if imageUrls, ok := params["image_urls"].([]any); ok { - for _, url := range imageUrls { - if urlStr, ok := url.(string); ok { - req.ImageUrls = append(req.ImageUrls, urlStr) + // 把 size 转成 width 和 height + if size, ok := params["size"]; ok { + if sizeStr, ok := size.(string); ok { + if strings.Contains(sizeStr, "x") { + sizes := strings.Split(sizeStr, "x") + params["width"] = sizes[0] + params["height"] = sizes[1] } } + delete(params, "size") } - if binaryData, ok := params["binary_data_base64"].([]any); ok { - for _, data := range binaryData { - if dataStr, ok := data.(string); ok { - req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr) - } - } - } - if scale, ok := params["scale"]; ok { - if scaleVal, ok := scale.(float64); ok { - req.Scale = scaleVal - } - } - s.setCommonParams(req, params) -} -// setImageEffectsParams 设置图像特效参数 -func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) { - if imageInput1, ok := params["image_input1"].(string); ok { - req.ImageInput1 = imageInput1 - } - if templateId, ok := params["template_id"].(string); ok { - req.TemplateId = templateId - } - if width, ok := params["width"]; ok { - if widthVal, ok := width.(float64); ok { - req.Width = int(widthVal) + // duration 转成 frames + if duration, ok := params["duration"]; ok { + if secs, ok := duration.(int); ok { + params["frames"] = secs*24 + 1 } + delete(params, "duration") } - if height, ok := params["height"]; ok { - if heightVal, ok := height.(float64); ok { - req.Height = int(heightVal) - } - } -} -// setTextToVideoParams 设置文生视频参数 -func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) { - if aspectRatio, ok := params["aspect_ratio"].(string); ok { - req.AspectRatio = aspectRatio - } - s.setCommonParams(req, params) -} - -// setImageToVideoParams 设置图生视频参数 -func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) { - s.setImageEditParams(req, params) // 复用图像编辑的参数设置 - if aspectRatio, ok := params["aspect_ratio"].(string); ok { - req.AspectRatio = aspectRatio - } -} - -// setCommonParams 设置通用参数(seed, width, height等) -func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) { - if seed, ok := params["seed"]; ok { - if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { - req.Seed = seedVal - } - } - if width, ok := params["width"]; ok { - if widthVal, ok := width.(float64); ok { - req.Width = int(widthVal) - } - } - if height, ok := params["height"]; ok { - if heightVal, ok := height.(float64); ok { - req.Height = int(heightVal) - } - } + // 删除多余参数,剩下的就是各个任务自己专有参数了 + delete(params, "type") + delete(params, "power") + return params, nil } // pollTaskStatus 轮询任务状态 @@ -368,6 +270,11 @@ func (s *Service) pollTaskStatus() { continue } + // 豆包生图 4.0 是同步任务,不需要轮询 + if job.ReqKey == DoubaoSeedream40ReqKey { + continue + } + // 查询任务状态 resp, err := s.client.QueryTask(&QueryTaskRequest{ ReqKey: job.ReqKey, @@ -384,7 +291,7 @@ func (s *Service) pollTaskStatus() { rawData, _ := json.Marshal(resp) s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData)) - if resp.Code != 10000 { + if resp.Code != CodeSuccess { s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message)) continue } diff --git a/api/service/jimeng/types.go b/api/service/jimeng/types.go index 8b2286f2..5b1acd1e 100644 --- a/api/service/jimeng/types.go +++ b/api/service/jimeng/types.go @@ -4,32 +4,6 @@ import ( "geekai/core/types" ) -// SubmitTaskRequest 提交任务请求 -type SubmitTaskRequest struct { - ReqKey string `json:"req_key"` - // 文生图参数 - Prompt string `json:"prompt,omitempty"` - Seed int64 `json:"seed,omitempty"` - Scale float64 `json:"scale,omitempty"` - Width int `json:"width,omitempty"` - Height int `json:"height,omitempty"` - UsePreLLM bool `json:"use_pre_llm,omitempty"` - // 图生图参数 - ImageInput string `json:"image_input,omitempty"` - ImageUrls []string `json:"image_urls,omitempty"` - BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` - Gpen float64 `json:"gpen,omitempty"` - Skin float64 `json:"skin,omitempty"` - SkinUnifi float64 `json:"skin_unifi,omitempty"` - GenMode string `json:"gen_mode,omitempty"` - // 图像编辑参数 - // 图像特效参数 - ImageInput1 string `json:"image_input1,omitempty"` - TemplateId string `json:"template_id,omitempty"` - // 视频生成参数 - AspectRatio string `json:"aspect_ratio,omitempty"` -} - // SubmitTaskResponse 提交任务响应 type SubmitTaskResponse struct { Code int `json:"code"` @@ -75,6 +49,8 @@ type QueryTaskResponse struct { } `json:"data"` } +const CodeSuccess = 10000 + // CreateTaskRequest 创建任务请求 type CreateTaskRequest struct { Type types.JMTaskType `json:"type"` @@ -84,3 +60,8 @@ type CreateTaskRequest struct { ImageUrls []string `json:"image_urls,omitempty"` Power int `json:"power,omitempty"` } + +const ( + ImageEffectReqKey = "i2i_multi_style_zx2x" + DoubaoSeedream40ReqKey = "doubao-seedream-4-0-250828" +) diff --git a/api/utils/media_duration.go b/api/utils/media_duration.go new file mode 100644 index 00000000..a53d2be8 --- /dev/null +++ b/api/utils/media_duration.go @@ -0,0 +1,817 @@ +package utils + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net/http" + "os" + "time" +) + +// AudioDuration returns duration of an audio file. +// Supported formats: MP3, WAV (auto-detected by header) +func AudioDuration(path string) (time.Duration, error) { + f, err := os.Open(path) + if err != nil { + return 0, err + } + defer f.Close() + + // Peek first 12 bytes to detect format + head := make([]byte, 12) + n, err := io.ReadFull(f, head) + if err != nil { + return 0, err + } + if n < 12 { + return 0, errors.New("file too small") + } + + // WAV: RIFF....WAVE + if string(head[0:4]) == "RIFF" && string(head[8:12]) == "WAVE" { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return 0, err + } + return wavDuration(f) + } + + // MP3 can start with ID3 or frame sync 0xFFEx + if string(head[0:3]) == "ID3" || (head[0] == 0xFF && (head[1]&0xE0) == 0xE0) { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return 0, err + } + return mp3Duration(f) + } + + return 0, errors.New("unsupported audio format") +} + +// AudioDurationFromURL downloads the url to a temp file and returns duration. +func AudioDurationFromURL(url string) (time.Duration, error) { + path, err := fetchURLToTemp(url, 30*time.Second) + if err != nil { + return 0, err + } + defer os.Remove(path) + return AudioDuration(path) +} + +// VideoDurationMP4 returns duration of an MP4 file (MOV/MP4 base media). +func VideoDurationMP4(path string) (time.Duration, error) { + f, err := os.Open(path) + if err != nil { + return 0, err + } + defer f.Close() + + return mp4Duration(f) +} + +// VideoDurationMP4FromURL downloads the url to a temp file and returns duration. +func VideoDurationMP4FromURL(url string) (time.Duration, error) { + path, err := fetchURLToTemp(url, 30*time.Second) + if err != nil { + return 0, err + } + defer os.Remove(path) + return VideoDurationMP4(path) +} + +// ---------------------- helpers ---------------------- + +func fetchURLToTemp(url string, timeout time.Duration) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("http status: %d", resp.StatusCode) + } + + tmp, err := os.CreateTemp("", "media-*") + if err != nil { + return "", err + } + defer tmp.Close() + + if _, err := io.Copy(tmp, resp.Body); err != nil { + path := tmp.Name() + _ = os.Remove(path) + return "", err + } + return tmp.Name(), nil +} + +// ---------------------- WAV ---------------------- + +func wavDuration(r io.ReadSeeker) (time.Duration, error) { + // RIFF header already checked outside if needed. We parse chunks to get fmt and data. + // WAV little-endian + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, err + } + + // Read RIFF header (12 bytes) + head := make([]byte, 12) + if _, err := io.ReadFull(r, head); err != nil { + return 0, err + } + if string(head[0:4]) != "RIFF" || string(head[8:12]) != "WAVE" { + return 0, errors.New("invalid wav header") + } + + var sampleRate uint32 + var numChans uint16 + var bitsPerSample uint16 + var byteRate uint32 + var dataSize uint32 + + for { + chunkHdr := make([]byte, 8) + if _, err := io.ReadFull(r, chunkHdr); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + return 0, err + } + ckID := string(chunkHdr[0:4]) + ckSize := binary.LittleEndian.Uint32(chunkHdr[4:8]) + + switch ckID { + case "fmt ": + fmtData := make([]byte, ckSize) + if _, err := io.ReadFull(r, fmtData); err != nil { + return 0, err + } + // audioFormat := binary.LittleEndian.Uint16(fmtData[0:2]) // 1 = PCM + numChans = binary.LittleEndian.Uint16(fmtData[2:4]) + sampleRate = binary.LittleEndian.Uint32(fmtData[4:8]) + byteRate = binary.LittleEndian.Uint32(fmtData[8:12]) + // blockAlign := binary.LittleEndian.Uint16(fmtData[12:14]) + if len(fmtData) >= 16 { + bitsPerSample = binary.LittleEndian.Uint16(fmtData[14:16]) + } + case "data": + dataSize = ckSize + // Skip data content + if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil { + return 0, err + } + default: + // Skip other chunks + if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil { + return 0, err + } + } + // Chunks are word-aligned (pad byte if odd size) + if ckSize%2 == 1 { + if _, err := r.Seek(1, io.SeekCurrent); err != nil { // skip pad + return 0, err + } + } + } + + if sampleRate == 0 || numChans == 0 { + return 0, errors.New("invalid wav fmt") + } + + var durationSeconds float64 + if byteRate != 0 { + durationSeconds = float64(dataSize) / float64(byteRate) + } else { + bytesPerSec := float64(sampleRate) * float64(numChans) * float64(bitsPerSample) / 8.0 + if bytesPerSec == 0 { + return 0, errors.New("invalid wav parameters") + } + durationSeconds = float64(dataSize) / bytesPerSec + } + return time.Duration(durationSeconds * float64(time.Second)), nil +} + +// ---------------------- MP3 ---------------------- + +func mp3Duration(r io.ReadSeeker) (time.Duration, error) { + // Strategy: + // 1) Skip ID3v2 header if present. + // 2) Try read first frame and detect XING/Info or VBRI to get total frames and duration. + // 3) If VBR headers not present, fall back to CBR estimation: (audioDataBytes * 8) / bitrate. + + // File size + fi, err := fileSizeFromSeeker(r) + if err != nil { + return 0, err + } + + // Skip ID3v2 + var id3v2Size int64 + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, err + } + id3v2Size, err = skipID3v2(r) + if err != nil { + return 0, err + } + + // Remember audio start offset + startOffset, _ := r.Seek(0, io.SeekCurrent) + + // Read first frame header (search sync) + off, fh, err := findNextMP3Frame(r) + if err != nil { + return 0, err + } + if _, err := r.Seek(off, io.SeekStart); err != nil { + return 0, err + } + + // Check for XING/Info header in first frame (for VBR) + totalFrames, sampleRate, samplesPerFrame, bitrateKbps, vbrFound, err := parseFirstFrameForVBR(r, fh) + if err != nil { + return 0, err + } + + if vbrFound && totalFrames > 0 && sampleRate > 0 && samplesPerFrame > 0 { + seconds := (float64(totalFrames) * float64(samplesPerFrame)) / float64(sampleRate) + return time.Duration(seconds * float64(time.Second)), nil + } + + // Fall back to CBR estimate using bitrate and data size (excluding ID3v2 and ID3v1) + // Detect ID3v1 at end (128 bytes TAG) + var id3v1Size int64 + if fi >= 128 { + if _, err := r.Seek(fi-128, io.SeekStart); err == nil { + buf := make([]byte, 3) + if _, err := io.ReadFull(r, buf); err == nil { + if string(buf) == "TAG" { + id3v1Size = 128 + } + } + } + } + + audioBytes := fi - id3v2Size - id3v1Size - startOffset + if audioBytes <= 0 || bitrateKbps == 0 { + return 0, errors.New("unable to estimate mp3 duration") + } + // bitrateKbps in kbps, bytes -> bits + seconds := float64(audioBytes*8) / float64(bitrateKbps*1000) + return time.Duration(seconds * float64(time.Second)), nil +} + +type mp3FrameHeader struct { + Version int // 1: MPEG1, 2: MPEG2, 25: MPEG2.5 + Layer int // 1,2,3 + BitrateKbps int + SampleRate int + Padding int + ChannelMode int // 0:Stereo,1:Joint,2:Dual,3:Mono +} + +func findNextMP3Frame(r io.ReadSeeker) (int64, mp3FrameHeader, error) { + var hdr mp3FrameHeader + // Start from current pos and scan up to 64KB + start, _ := r.Seek(0, io.SeekCurrent) + limit := int64(64 * 1024) + buf := make([]byte, limit) + n, err := r.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return 0, hdr, err + } + for i := 0; i+4 <= n; i++ { + if buf[i] == 0xFF && (buf[i+1]&0xE0) == 0xE0 { // sync + if h, ok := parseMP3Header(buf[i : i+4]); ok { + offset := start + int64(i) + return offset, h, nil + } + } + } + return 0, hdr, errors.New("mp3 frame not found") +} + +func parseMP3Header(b []byte) (mp3FrameHeader, bool) { + var h mp3FrameHeader + if len(b) < 4 { + return h, false + } + if b[0] != 0xFF || (b[1]&0xE0) != 0xE0 { + return h, false + } + versionBits := (b[1] >> 3) & 0x03 + layerBits := (b[1] >> 1) & 0x03 + bitrateBits := (b[2] >> 4) & 0x0F + sampleRateBits := (b[2] >> 2) & 0x03 + paddingBit := (b[2] >> 1) & 0x01 + channelMode := (b[3] >> 6) & 0x03 + + var version int + switch versionBits { + case 0x00: + version = 25 // MPEG 2.5 + case 0x02: + version = 2 // MPEG 2 + case 0x03: + version = 1 // MPEG 1 + default: + return h, false + } + + var layer int + switch layerBits { + case 0x01: + layer = 3 + case 0x02: + layer = 2 + case 0x03: + layer = 1 + default: + return h, false + } + + br := mp3BitrateKbps(version, layer, int(bitrateBits)) + if br == 0 { + return h, false + } + sr := mp3SampleRate(version, int(sampleRateBits)) + if sr == 0 { + return h, false + } + + h = mp3FrameHeader{ + Version: version, + Layer: layer, + BitrateKbps: br, + SampleRate: sr, + Padding: int(paddingBit), + ChannelMode: int(channelMode), + } + return h, true +} + +func mp3BitrateKbps(version, layer, index int) int { + // index: 1..14 valid; 0,15 invalid + if index <= 0 || index == 15 { + return 0 + } + // Tables per ISO/IEC 11172-3/13818-3 (common subset) + var tbl [15]int + if layer == 1 { // Layer I + if version == 1 { // MPEG1 + tbl = [15]int{0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448} + } else { // MPEG2/2.5 + tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256} + } + } else if layer == 2 { // Layer II + if version == 1 { + tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384} + } else { + tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160} + } + } else { // Layer III + if version == 1 { + tbl = [15]int{0, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320} + } else { + tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160} + } + } + return tbl[index] +} + +func mp3SampleRate(version, index int) int { + if index == 3 { + return 0 + } + // base table for MPEG1 + base := [3]int{44100, 48000, 32000} + sr := base[index] + if version == 2 { // MPEG2 + sr /= 2 + } else if version == 25 { // MPEG2.5 + sr /= 4 + } + return sr +} + +func samplesPerMP3Frame(version, layer int) int { + switch layer { + case 1: + return 384 + case 2: + return 1152 + case 3: + if version == 1 { + return 1152 + } + return 576 // MPEG2/2.5 Layer III + default: + return 0 + } +} + +func parseFirstFrameForVBR(r io.ReadSeeker, fh mp3FrameHeader) (totalFrames uint32, sampleRate int, samplesPerFrame int, bitrateKbps int, vbrFound bool, err error) { + // After the 4-byte header, possible side info and then XING/Info + if _, err = r.Seek(0, io.SeekCurrent); err != nil { + return + } + // Re-read header + hdr := make([]byte, 4) + if _, err = io.ReadFull(r, hdr); err != nil { + return + } + + // side info size depends on MPEG version and channel mode (for Layer III) + sideInfoSize := 0 + if fh.Layer == 3 { // Layer III + if fh.Version == 1 { // MPEG1 + if fh.ChannelMode == 3 { // mono + sideInfoSize = 17 + } else { + sideInfoSize = 32 + } + } else { // MPEG2/2.5 + if fh.ChannelMode == 3 { + sideInfoSize = 9 + } else { + sideInfoSize = 17 + } + } + } + + // Read next up to 120 bytes to search for XING/Info or VBRI + buf := make([]byte, sideInfoSize+120) + if _, err = io.ReadFull(r, buf); err != nil { + // If short, still try within available + if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return + } + } + + // Search XING/Info signature + sigs := [][]byte{[]byte("Xing"), []byte("Info")} + for _, sig := range sigs { + idx := indexOf(buf, sig) + if idx >= 0 { + // flags after signature (4 bytes), then if frames flag set, 4 bytes frames + if len(buf) >= idx+4+4 { + flags := binary.BigEndian.Uint32(buf[idx+4 : idx+8]) + var frames uint32 + if (flags & 0x01) != 0 { // frames present + if len(buf) >= idx+8+4 { + frames = binary.BigEndian.Uint32(buf[idx+8 : idx+12]) + } + } + if frames > 0 { + vbrFound = true + totalFrames = frames + sampleRate = fh.SampleRate + samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer) + bitrateKbps = fh.BitrateKbps + return + } + } + } + } + + // Check VBRI (usually at 32 bytes after header for MPEG1 Layer III) + if len(buf) >= 4 { + idx := indexOf(buf, []byte("VBRI")) + if idx >= 0 { + if len(buf) >= idx+4+2+2+4+4 { + // VBRI layout: 'VBRI'(4) + version(2) + delay(2) + quality(2?) varies; but at offset 10 comes bytes: bytes (4), frames (4) + // Some docs: offset 10: bytes, offset 14: frames (big-endian) + bytesOffset := idx + 10 + framesOffset := idx + 14 + if len(buf) >= framesOffset+4 { + frames := binary.BigEndian.Uint32(buf[framesOffset : framesOffset+4]) + if frames > 0 { + vbrFound = true + totalFrames = frames + sampleRate = fh.SampleRate + samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer) + bitrateKbps = fh.BitrateKbps + _ = bytesOffset // not used + return + } + } + } + } + } + + // No VBR header. Provide header info for CBR fallback + sampleRate = fh.SampleRate + samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer) + bitrateKbps = fh.BitrateKbps + return +} + +func indexOf(haystack []byte, needle []byte) int { + for i := 0; i+len(needle) <= len(haystack); i++ { + match := true + for j := 0; j < len(needle); j++ { + if haystack[i+j] != needle[j] { + match = false + break + } + } + if match { + return i + } + } + return -1 +} + +func skipID3v2(r io.ReadSeeker) (int64, error) { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, err + } + head := make([]byte, 10) + if _, err := io.ReadFull(r, head); err != nil { + return 0, nil // no header + } + if string(head[0:3]) != "ID3" { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, err + } + return 0, nil + } + // size: 4 synchsafe bytes + sz := int64((int(head[6]&0x7F) << 21) | (int(head[7]&0x7F) << 14) | (int(head[8]&0x7F) << 7) | int(head[9]&0x7F)) + // total header size = 10 + sz (+ footer 10 if flag set) + footer := int64(0) + if (head[5] & 0x10) != 0 { // footer present + footer = 10 + } + total := 10 + sz + footer + if _, err := r.Seek(total, io.SeekStart); err != nil { + return 0, err + } + return total, nil +} + +func fileSizeFromSeeker(r io.ReadSeeker) (int64, error) { + cur, err := r.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + end, err := r.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + if _, err := r.Seek(cur, io.SeekStart); err != nil { + return 0, err + } + return end, nil +} + +// ---------------------- MP4 ---------------------- + +type mp4BoxHeader struct { + Size uint64 + Type [4]byte +} + +func readBoxHeader(r io.ReadSeeker) (mp4BoxHeader, error) { + var h mp4BoxHeader + buf := make([]byte, 8) + if _, err := io.ReadFull(r, buf); err != nil { + return h, err + } + sz := binary.BigEndian.Uint32(buf[0:4]) + copy(h.Type[:], buf[4:8]) + if sz == 1 { + // 64-bit size follows + ext := make([]byte, 8) + if _, err := io.ReadFull(r, ext); err != nil { + return h, err + } + h.Size = binary.BigEndian.Uint64(ext) + } else { + h.Size = uint64(sz) + } + return h, nil +} + +func skipBox(r io.ReadSeeker, boxSize uint64, alreadyRead int64) error { + toSkip := int64(boxSize) - alreadyRead + if toSkip < 0 { + return fmt.Errorf("invalid box size") + } + _, err := r.Seek(toSkip, io.SeekCurrent) + return err +} + +func mp4Duration(r io.ReadSeeker) (time.Duration, error) { + if _, err := r.Seek(0, io.SeekStart); err != nil { + return 0, err + } + // Find moov box + var moovStart int64 + var moovSize uint64 + for { + pos, _ := r.Seek(0, io.SeekCurrent) + h, err := readBoxHeader(r) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + return 0, err + } + if string(h.Type[:]) == "moov" { + moovStart = pos + moovSize = h.Size + break + } + if h.Size < 8 { + return 0, errors.New("invalid mp4 box size") + } + if err := skipBox(r, h.Size, 8); err != nil { + return 0, err + } + } + if moovStart == 0 && moovSize == 0 { + return 0, errors.New("moov not found") + } + // Parse inside moov for video trak mdhd, else mvhd + if _, err := r.Seek(moovStart+8, io.SeekStart); err != nil { // skip moov header + return 0, err + } + end := moovStart + int64(moovSize) + var movieTimescale uint32 + var movieDuration uint64 + var foundVideoMdhd bool + var mdhdTimescale uint32 + var mdhdDuration uint64 + + for { + pos, _ := r.Seek(0, io.SeekCurrent) + if pos >= end { + break + } + h, err := readBoxHeader(r) + if err != nil { + return 0, err + } + switch string(h.Type[:]) { + case "mvhd": + // movie header + ver := make([]byte, 1) + if _, err := io.ReadFull(r, ver); err != nil { + return 0, err + } + if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags + return 0, err + } + if ver[0] == 1 { + // version 1: 64-bit duration + buf := make([]byte, 8+8+4+8) // ctime(8) mtime(8) timescale(4) duration(8) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + movieTimescale = binary.BigEndian.Uint32(buf[16:20]) + movieDuration = binary.BigEndian.Uint64(buf[20:28]) + } else { + buf := make([]byte, 4+4+4+4) // ctime mtime timescale duration + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + movieTimescale = binary.BigEndian.Uint32(buf[8:12]) + movieDuration = uint64(binary.BigEndian.Uint32(buf[12:16])) + } + // skip rest of mvhd + read := int64(1 + 3) + if ver[0] == 1 { + read += int64(8 + 8 + 4 + 8) + } else { + read += int64(4 + 4 + 4 + 4) + } + if err := skipBox(r, h.Size, 8+read); err != nil { + return 0, err + } + case "trak": + // parse trak for hdlr and mdhd + tEnd := int64(0) + if h.Size < 8 { + return 0, errors.New("invalid trak size") + } + tEnd = pos + int64(h.Size) + var isVideo bool + var tMdhdTimescale uint32 + var tMdhdDuration uint64 + for { + cpos, _ := r.Seek(0, io.SeekCurrent) + if cpos >= tEnd { + break + } + ch, err := readBoxHeader(r) + if err != nil { + return 0, err + } + switch string(ch.Type[:]) { + case "mdia": + mEnd := cpos + int64(ch.Size) + for { + mpos, _ := r.Seek(0, io.SeekCurrent) + if mpos >= mEnd { + break + } + mh, err := readBoxHeader(r) + if err != nil { + return 0, err + } + switch string(mh.Type[:]) { + case "hdlr": + // skip version+flags (4), pre_defined(4) + if _, err := r.Seek(8, io.SeekCurrent); err != nil { + return 0, err + } + handler := make([]byte, 4) + if _, err := io.ReadFull(r, handler); err != nil { + return 0, err + } + if string(handler) == "vide" { + isVideo = true + } + if err := skipBox(r, mh.Size, 8+8+4); err != nil { // header + skipped + read handler + return 0, err + } + case "mdhd": + ver := make([]byte, 1) + if _, err := io.ReadFull(r, ver); err != nil { + return 0, err + } + if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags + return 0, err + } + if ver[0] == 1 { + buf := make([]byte, 8+8+4+8) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + tMdhdTimescale = binary.BigEndian.Uint32(buf[16:20]) + tMdhdDuration = binary.BigEndian.Uint64(buf[20:28]) + } else { + buf := make([]byte, 4+4+4+4) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + tMdhdTimescale = binary.BigEndian.Uint32(buf[8:12]) + tMdhdDuration = uint64(binary.BigEndian.Uint32(buf[12:16])) + } + if err := skipBox(r, mh.Size, 8+1+3+int64(lenVersionPayload(ver[0]))); err != nil { + return 0, err + } + default: + if err := skipBox(r, mh.Size, 8); err != nil { + return 0, err + } + } + } + default: + if err := skipBox(r, ch.Size, 8); err != nil { + return 0, err + } + } + } + if isVideo && tMdhdTimescale != 0 && tMdhdDuration != 0 { + foundVideoMdhd = true + mdhdTimescale = tMdhdTimescale + mdhdDuration = tMdhdDuration + } + // Skip remaining of trak if any + if _, err := r.Seek(tEnd, io.SeekStart); err != nil { + return 0, err + } + default: + if err := skipBox(r, h.Size, 8); err != nil { + return 0, err + } + } + } + + if foundVideoMdhd && mdhdTimescale != 0 { + sec := float64(mdhdDuration) / float64(mdhdTimescale) + return time.Duration(sec * float64(time.Second)), nil + } + if movieTimescale != 0 { + sec := float64(movieDuration) / float64(movieTimescale) + return time.Duration(sec * float64(time.Second)), nil + } + return 0, errors.New("failed to read mp4 duration") +} + +func lenVersionPayload(ver byte) int { + if ver == 1 { + return 8 + 8 + 4 + 8 + } + return 4 + 4 + 4 + 4 +} diff --git a/web/src/components/ParamBuilder.vue b/web/src/components/ParamBuilder.vue index be01fe94..d6d914a4 100644 --- a/web/src/components/ParamBuilder.vue +++ b/web/src/components/ParamBuilder.vue @@ -225,7 +225,7 @@ const initModelValue = (model) => { } }) } - defaultValues.model = selectedModel.value.key + defaultValues.req_key = selectedModel.value.key return defaultValues } diff --git a/web/src/components/ui/Alert.vue b/web/src/components/ui/Alert.vue index cb187cc7..c7026e8a 100644 --- a/web/src/components/ui/Alert.vue +++ b/web/src/components/ui/Alert.vue @@ -23,7 +23,7 @@ const props = defineProps({ const typeClass = computed(() => { return { - info: 'bg-blue-100 text-blue-500 border-blue-500', + info: 'bg-blue-100 text-white-500 border-blue-500', success: 'bg-green-100 text-green-500 border-green-500', warning: 'bg-yellow-100 text-yellow-500 border-yellow-500', }[props.type] diff --git a/web/src/store/data/jimeng_data.js b/web/src/store/data/jimeng_data.js index 3288ae0f..562cc2b7 100644 --- a/web/src/store/data/jimeng_data.js +++ b/web/src/store/data/jimeng_data.js @@ -537,6 +537,22 @@ export const JimengParams = { label: '21:9 (2016 * 864)', value: '2016x864', }, + { + label: '9:16 (936 * 1664)', + value: '936x1664', + }, + { + label: '2:3 (1056 * 1584)', + value: '1056x1584', + }, + { + label: '3:4 (1104 * 1472)', + value: '1104x1472', + }, + { + label: '9:21 (864 * 2016)', + value: '864x2016', + }, ], }, ], diff --git a/web/src/store/jimeng.js b/web/src/store/jimeng.js index 0ac9c1ac..13239cb1 100644 --- a/web/src/store/jimeng.js +++ b/web/src/store/jimeng.js @@ -195,14 +195,19 @@ export const useJimengStore = defineStore('jimeng', () => { try { submitting.value = true - - const response = await httpPost('/api/jimeng/task', formData.value) - if (response.data) { - showMessageOK('任务提交成功') - isOver.value = false - await fetchData(1) - startPolling() + formData.value.type = activeFunction.value + // 视频 duration 转成整数 + if (formData.value.duration) { + formData.value.duration = parseInt(formData.value.duration) } + if (formData.value.image_urls && !Array.isArray(formData.value.image_urls)) { + formData.value.image_urls = [formData.value.image_urls] + } + const response = await httpPost('/api/jimeng/task', formData.value) + showMessageOK('任务提交成功') + isOver.value = false + await fetchData(1) + startPolling() } catch (error) { console.error('提交任务失败:', error) showMessageError(error.message || '提交任务失败') diff --git a/web/src/views/admin/jimeng/JimengConfig.vue b/web/src/views/admin/jimeng/JimengConfig.vue index 99cf5e7c..b8698d61 100644 --- a/web/src/views/admin/jimeng/JimengConfig.vue +++ b/web/src/views/admin/jimeng/JimengConfig.vue @@ -24,10 +24,15 @@ > 和 智能绘图以及火山方舟 服务。

@@ -41,6 +46,17 @@ > 获取。

+

+ 3. ApiKey 请在火山方舟控制台 -> + + API Key管理 + 获取。 +

@@ -49,6 +65,19 @@ + + + +
+ 目前豆包生图 4.0 模型在即梦API中不支持,需要使用火山方舟服务。 +
+