mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-11 04:04:29 +08:00
acommpelish jimeng AI refactor for PC
This commit is contained in:
@@ -2,9 +2,13 @@ package types
|
|||||||
|
|
||||||
// JimengConfig 即梦AI配置
|
// JimengConfig 即梦AI配置
|
||||||
type JimengConfig struct {
|
type JimengConfig struct {
|
||||||
AccessKey string `json:"access_key"`
|
// 即梦AI的AccessKey和SecretKey
|
||||||
SecretKey string `json:"secret_key"`
|
AccessKey string `json:"access_key"`
|
||||||
Power JimengPower `json:"power"`
|
SecretKey string `json:"secret_key"`
|
||||||
|
// 火山引擎大模型专用的验证方式
|
||||||
|
ApiKey string `json:"api_key"`
|
||||||
|
// 算力配置
|
||||||
|
Power JimengPower `json:"power"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// JimengPower 即梦AI算力配置
|
// JimengPower 即梦AI算力配置
|
||||||
@@ -40,7 +44,9 @@ const (
|
|||||||
|
|
||||||
// JimengTaskRequest 即梦AI任务请求
|
// JimengTaskRequest 即梦AI任务请求
|
||||||
type JimengTaskRequest struct {
|
type JimengTaskRequest struct {
|
||||||
ReqKey string `json:"req_key"` // 请求Key
|
TaskType JMTaskType `json:"type"` // 任务类型
|
||||||
|
ReqKey string `json:"req_key"` // 请求Key
|
||||||
|
Power int `json:"power"` // 消耗算力
|
||||||
// 公共参数
|
// 公共参数
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
ImageUrls []string `json:"image_urls,omitempty"`
|
ImageUrls []string `json:"image_urls,omitempty"`
|
||||||
@@ -50,7 +56,7 @@ type JimengTaskRequest struct {
|
|||||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||||
|
|
||||||
// 视频生成参数
|
// 视频生成参数
|
||||||
Duration string `json:"duration,omitempty"` // 视频时长
|
Duration int `json:"duration,omitempty"` // 视频时长,单位:秒
|
||||||
TemplateId string `json:"template_id,omitempty"` // 运镜模板ID
|
TemplateId string `json:"template_id,omitempty"` // 运镜模板ID
|
||||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||||
CameraStrength string `json:"camera_strength,omitempty"` // 运镜强度
|
CameraStrength string `json:"camera_strength,omitempty"` // 运镜强度
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ require (
|
|||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
github.com/shopspring/decimal v1.3.1
|
github.com/shopspring/decimal v1.3.1
|
||||||
github.com/syndtr/goleveldb v1.0.0
|
github.com/syndtr/goleveldb v1.0.0
|
||||||
|
github.com/volcengine/volcengine-go-sdk v1.1.34
|
||||||
golang.org/x/image v0.15.0
|
golang.org/x/image v0.15.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ require (
|
|||||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.uber.org/mock v0.4.0 // indirect
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W
|
|||||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
@@ -110,6 +111,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
|||||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
||||||
@@ -259,6 +261,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
|
|||||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||||
|
github.com/volcengine/volcengine-go-sdk v1.1.34 h1:ha90JycCCTJNCse0UDziBgBsuX98ITOrkwYlDWcm7NI=
|
||||||
|
github.com/volcengine/volcengine-go-sdk v1.1.34/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
@@ -390,6 +394,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
|
|||||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||||
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
|
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core"
|
"geekai/core"
|
||||||
"geekai/core/middleware"
|
"geekai/core/middleware"
|
||||||
@@ -94,32 +95,29 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取算力消耗
|
// 获取算力消耗
|
||||||
|
powerCost, err := h.getTaskPower(req)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "计算任务消耗积分失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// if user.Power < powerCost {
|
if user.Power < powerCost {
|
||||||
// resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||||
// return
|
return
|
||||||
// }
|
}
|
||||||
|
|
||||||
// taskReq := &jimeng.CreateTaskRequest{
|
job, err := h.jimengService.CreateTask(user.Id, &req)
|
||||||
// Type: taskType,
|
if err != nil {
|
||||||
// Prompt: req.Prompt,
|
logger.Errorf("create jimeng task failed: %v", err)
|
||||||
// Params: params,
|
resp.ERROR(c, "创建任务失败")
|
||||||
// ReqKey: reqKey,
|
return
|
||||||
// Power: powerCost,
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
// job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||||
// if err != nil {
|
Type: types.PowerConsume,
|
||||||
// logger.Errorf("create jimeng task failed: %v", err)
|
Model: job.ReqKey,
|
||||||
// resp.ERROR(c, "创建任务失败")
|
Remark: fmt.Sprintf("%s,任务ID:%d", req.ReqKey, job.Id),
|
||||||
// return
|
})
|
||||||
// }
|
|
||||||
|
|
||||||
// h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
|
||||||
// Type: types.PowerConsume,
|
|
||||||
// Model: "jimeng",
|
|
||||||
// Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
|
||||||
// })
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
@@ -224,7 +222,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
|||||||
if job.Status != types.JMTaskStatusFailed {
|
if job.Status != types.JMTaskStatusFailed {
|
||||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||||
Type: types.PowerRefund,
|
Type: types.PowerRefund,
|
||||||
Model: "jimeng",
|
Model: job.ReqKey,
|
||||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -285,20 +283,24 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||||
func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int {
|
func (h *JimengHandler) getTaskPower(req types.JimengTaskRequest) (int, error) {
|
||||||
config := h.App.SysConfig.Jimeng
|
config := h.App.SysConfig.Jimeng
|
||||||
|
switch req.TaskType {
|
||||||
switch taskType {
|
|
||||||
case types.JMTaskTypeImage:
|
case types.JMTaskTypeImage:
|
||||||
return config.Power.Image
|
return config.Power.Image, nil
|
||||||
case types.JMTaskTypeVideo:
|
case types.JMTaskTypeVideo:
|
||||||
return config.Power.Video
|
if req.Duration == 0 {
|
||||||
|
return 0, errors.New("视频时长不能为0")
|
||||||
|
}
|
||||||
|
return config.Power.Video * req.Duration, nil
|
||||||
case types.JMTaskTypeVirtualHuman:
|
case types.JMTaskTypeVirtualHuman:
|
||||||
return config.Power.VirtualHuman
|
// TODO 计算音频时长
|
||||||
|
return config.Power.VirtualHuman, nil
|
||||||
case types.JMTaskTypeActionTransfer:
|
case types.JMTaskTypeActionTransfer:
|
||||||
return config.Power.ActionTransfer
|
// TODO 计算视频时长
|
||||||
|
return config.Power.ActionTransfer, nil
|
||||||
default:
|
default:
|
||||||
return 10
|
return 0, errors.New("任务类型不支持")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,9 +308,9 @@ func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int {
|
|||||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||||
config := h.App.SysConfig.Jimeng
|
config := h.App.SysConfig.Jimeng
|
||||||
resp.SUCCESS(c, gin.H{
|
resp.SUCCESS(c, gin.H{
|
||||||
"image": config.Power.Image,
|
"image": config.Power.Image,
|
||||||
"video": config.Power.Video,
|
"video": config.Power.Video,
|
||||||
"image_edit": config.Power.VirtualHuman,
|
"virtual_human": config.Power.VirtualHuman,
|
||||||
"image_effects": config.Power.ActionTransfer,
|
"action_transfer": config.Power.ActionTransfer,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package jimeng
|
package jimeng
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
@@ -10,6 +11,9 @@ import (
|
|||||||
|
|
||||||
"github.com/volcengine/volc-sdk-golang/base"
|
"github.com/volcengine/volc-sdk-golang/base"
|
||||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client 即梦API客户端
|
// Client 即梦API客户端
|
||||||
@@ -94,13 +98,19 @@ func (c *Client) testConnection() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubmitTask 提交异步任务
|
// SubmitTask 提交异步任务
|
||||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) {
|
||||||
// 直接将请求转为map[string]interface{}
|
// 直接将请求转为map[string]interface{}
|
||||||
reqBodyBytes, err := json.Marshal(req)
|
reqBodyBytes, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 单独处理图片特效任务
|
||||||
|
if req["req_key"] == ImageEffectReqKey {
|
||||||
|
req["image_input1"] = req["image_urls"].([]any)[0]
|
||||||
|
delete(req, "image_urls")
|
||||||
|
}
|
||||||
|
|
||||||
// 直接使用序列化后的字节
|
// 直接使用序列化后的字节
|
||||||
jsonBody := reqBodyBytes
|
jsonBody := reqBodyBytes
|
||||||
|
|
||||||
@@ -146,27 +156,29 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
// SubmitSyncImageTask 提交同步生图任务
|
||||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) {
|
||||||
// 序列化请求
|
// 配置火山引擎访问密钥,目前只支持API Key验证
|
||||||
jsonBody, err := json.Marshal(req)
|
client := arkruntime.NewClientWithApiKey(c.config.ApiKey)
|
||||||
|
// 构造生图请求
|
||||||
|
sequentialImageGeneration := model.SequentialImageGeneration("disabled")
|
||||||
|
generateReq := model.GenerateImagesRequest{
|
||||||
|
Model: req.ReqKey, // 模型名称
|
||||||
|
Prompt: req.Prompt, // 提示词
|
||||||
|
Size: volcengine.String(req.Size), // 图片尺寸
|
||||||
|
SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成
|
||||||
|
ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL
|
||||||
|
Watermark: volcengine.Bool(false), // 不添加水印
|
||||||
|
OptimizePrompt: volcengine.Bool(true), // 优化提示词
|
||||||
|
}
|
||||||
|
if len(req.ImageUrls) > 0 {
|
||||||
|
generateReq.Image = req.ImageUrls
|
||||||
|
}
|
||||||
|
// 调用生图 API
|
||||||
|
resp, err := client.GenerateImages(context.Background(), generateReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用SDK的JSON方法
|
return &resp, nil
|
||||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
|
||||||
|
|
||||||
// 解析响应,同步任务直接返回结果
|
|
||||||
var result QueryTaskResponse
|
|
||||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &result, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -103,24 +103,18 @@ func (s *Service) processNextTask() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateTask 创建任务
|
// CreateTask 创建任务
|
||||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) {
|
||||||
// 生成任务ID
|
// 生成任务ID
|
||||||
taskId := utils.RandString(20)
|
taskId := utils.RandString(20)
|
||||||
|
|
||||||
// 序列化任务参数
|
|
||||||
paramsJson, err := json.Marshal(req.Params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建任务记录
|
// 创建任务记录
|
||||||
job := &model.JimengJob{
|
job := &model.JimengJob{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
TaskId: taskId,
|
TaskId: taskId,
|
||||||
Type: req.Type,
|
Type: req.TaskType,
|
||||||
ReqKey: req.ReqKey,
|
ReqKey: req.ReqKey,
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
Params: string(paramsJson),
|
Params: utils.JsonEncode(req),
|
||||||
Status: types.JMTaskStatusInQueue,
|
Status: types.JMTaskStatusInQueue,
|
||||||
Power: req.Power,
|
Power: req.Power,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -153,21 +147,61 @@ func (s *Service) ProcessTask(jobId uint) error {
|
|||||||
return fmt.Errorf("update job status failed: %w", err)
|
return fmt.Errorf("update job status failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析任务参数
|
||||||
|
var req types.JimengTaskRequest
|
||||||
|
err := utils.JsonDecode(job.Params, &req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse task params failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 构建请求并提交任务
|
// 构建请求并提交任务
|
||||||
req, err := s.buildTaskRequest(&job)
|
params, err := s.buildTaskRequest(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("提交即梦任务: %+v", req)
|
logger.Debugf("提交即梦任务: %+v", params)
|
||||||
|
|
||||||
// 提交异步任务
|
// 同步任务 ,后台执行
|
||||||
resp, err := s.client.SubmitTask(req)
|
if req.ReqKey == DoubaoSeedream40ReqKey {
|
||||||
|
go func() {
|
||||||
|
resp, err := s.client.SubmitSyncImageTask(req)
|
||||||
|
if err != nil {
|
||||||
|
_ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Infof("同步任务提交成功: %+v", resp)
|
||||||
|
// 更新原始数据
|
||||||
|
rawData, _ := json.Marshal(resp)
|
||||||
|
updates := map[string]any{
|
||||||
|
"raw_data": string(rawData),
|
||||||
|
}
|
||||||
|
if resp.Error != nil {
|
||||||
|
updates["status"] = types.JMTaskStatusFailed
|
||||||
|
updates["err_msg"] = resp.Error.Message
|
||||||
|
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新任务状态
|
||||||
|
updates["status"] = types.JMTaskStatusSuccess
|
||||||
|
// 下载图片
|
||||||
|
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false)
|
||||||
|
if err == nil {
|
||||||
|
updates["img_url"] = imgUrl
|
||||||
|
}
|
||||||
|
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步任务 ,前台执行
|
||||||
|
resp, err := s.client.SubmitTask(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
if resp.Code != CodeSuccess {
|
||||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,168 +219,36 @@ func (s *Service) ProcessTask(jobId uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) {
|
||||||
// 解析任务参数
|
|
||||||
var params map[string]any
|
var params map[string]any
|
||||||
if err := json.Unmarshal([]byte(job.Params), ¶ms); err != nil {
|
err := utils.JsonDecode(utils.JsonEncode(req), ¶ms)
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||||
}
|
}
|
||||||
|
// 把 size 转成 width 和 height
|
||||||
// 构建基础请求
|
if size, ok := params["size"]; ok {
|
||||||
req := &SubmitTaskRequest{
|
if sizeStr, ok := size.(string); ok {
|
||||||
ReqKey: job.ReqKey,
|
if strings.Contains(sizeStr, "x") {
|
||||||
Prompt: job.Prompt,
|
sizes := strings.Split(sizeStr, "x")
|
||||||
}
|
params["width"] = sizes[0]
|
||||||
|
params["height"] = sizes[1]
|
||||||
// 根据任务类型设置特定参数
|
|
||||||
switch job.Type {
|
|
||||||
case types.JMTaskTypeImage:
|
|
||||||
s.setTextToImageParams(req, params)
|
|
||||||
case types.JMTaskTypeVideo:
|
|
||||||
s.setImageToImageParams(req, params)
|
|
||||||
case types.JMTaskTypeVirtualHuman:
|
|
||||||
s.setImageEditParams(req, params)
|
|
||||||
case types.JMTaskTypeActionTransfer:
|
|
||||||
s.setImageEffectsParams(req, params)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setTextToImageParams 设置文生图参数
|
|
||||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
|
||||||
if seed, ok := params["seed"]; ok {
|
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
|
||||||
req.Seed = seedVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if scale, ok := params["scale"]; ok {
|
|
||||||
if scaleVal, ok := scale.(float64); ok {
|
|
||||||
req.Scale = scaleVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if width, ok := params["width"]; ok {
|
|
||||||
if widthVal, ok := width.(float64); ok {
|
|
||||||
req.Width = int(widthVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if height, ok := params["height"]; ok {
|
|
||||||
if heightVal, ok := height.(float64); ok {
|
|
||||||
req.Height = int(heightVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
|
||||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
|
||||||
req.UsePreLLM = usePreLlmVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setImageToImageParams 设置图生图参数
|
|
||||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
|
||||||
if imageInput, ok := params["image_input"].(string); ok {
|
|
||||||
req.ImageInput = imageInput
|
|
||||||
}
|
|
||||||
if gpen, ok := params["gpen"]; ok {
|
|
||||||
if gpenVal, ok := gpen.(float64); ok {
|
|
||||||
req.Gpen = gpenVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if skin, ok := params["skin"]; ok {
|
|
||||||
if skinVal, ok := skin.(float64); ok {
|
|
||||||
req.Skin = skinVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
|
||||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
|
||||||
req.SkinUnifi = skinUnifiVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if genMode, ok := params["gen_mode"].(string); ok {
|
|
||||||
req.GenMode = genMode
|
|
||||||
}
|
|
||||||
s.setCommonParams(req, params) // 复用通用参数
|
|
||||||
}
|
|
||||||
|
|
||||||
// setImageEditParams 设置图像编辑参数
|
|
||||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
|
||||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
|
||||||
for _, url := range imageUrls {
|
|
||||||
if urlStr, ok := url.(string); ok {
|
|
||||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
delete(params, "size")
|
||||||
}
|
}
|
||||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
|
||||||
for _, data := range binaryData {
|
|
||||||
if dataStr, ok := data.(string); ok {
|
|
||||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if scale, ok := params["scale"]; ok {
|
|
||||||
if scaleVal, ok := scale.(float64); ok {
|
|
||||||
req.Scale = scaleVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.setCommonParams(req, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setImageEffectsParams 设置图像特效参数
|
// duration 转成 frames
|
||||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
if duration, ok := params["duration"]; ok {
|
||||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
if secs, ok := duration.(int); ok {
|
||||||
req.ImageInput1 = imageInput1
|
params["frames"] = secs*24 + 1
|
||||||
}
|
|
||||||
if templateId, ok := params["template_id"].(string); ok {
|
|
||||||
req.TemplateId = templateId
|
|
||||||
}
|
|
||||||
if width, ok := params["width"]; ok {
|
|
||||||
if widthVal, ok := width.(float64); ok {
|
|
||||||
req.Width = int(widthVal)
|
|
||||||
}
|
}
|
||||||
|
delete(params, "duration")
|
||||||
}
|
}
|
||||||
if height, ok := params["height"]; ok {
|
|
||||||
if heightVal, ok := height.(float64); ok {
|
|
||||||
req.Height = int(heightVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setTextToVideoParams 设置文生视频参数
|
// 删除多余参数,剩下的就是各个任务自己专有参数了
|
||||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
delete(params, "type")
|
||||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
delete(params, "power")
|
||||||
req.AspectRatio = aspectRatio
|
return params, nil
|
||||||
}
|
|
||||||
s.setCommonParams(req, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setImageToVideoParams 设置图生视频参数
|
|
||||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
|
||||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
|
||||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
|
||||||
req.AspectRatio = aspectRatio
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCommonParams 设置通用参数(seed, width, height等)
|
|
||||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
|
||||||
if seed, ok := params["seed"]; ok {
|
|
||||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
|
||||||
req.Seed = seedVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if width, ok := params["width"]; ok {
|
|
||||||
if widthVal, ok := width.(float64); ok {
|
|
||||||
req.Width = int(widthVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if height, ok := params["height"]; ok {
|
|
||||||
if heightVal, ok := height.(float64); ok {
|
|
||||||
req.Height = int(heightVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// pollTaskStatus 轮询任务状态
|
// pollTaskStatus 轮询任务状态
|
||||||
@@ -368,6 +270,11 @@ func (s *Service) pollTaskStatus() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 豆包生图 4.0 是同步任务,不需要轮询
|
||||||
|
if job.ReqKey == DoubaoSeedream40ReqKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 查询任务状态
|
// 查询任务状态
|
||||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||||
ReqKey: job.ReqKey,
|
ReqKey: job.ReqKey,
|
||||||
@@ -384,7 +291,7 @@ func (s *Service) pollTaskStatus() {
|
|||||||
rawData, _ := json.Marshal(resp)
|
rawData, _ := json.Marshal(resp)
|
||||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||||
|
|
||||||
if resp.Code != 10000 {
|
if resp.Code != CodeSuccess {
|
||||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,32 +4,6 @@ import (
|
|||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SubmitTaskRequest 提交任务请求
|
|
||||||
type SubmitTaskRequest struct {
|
|
||||||
ReqKey string `json:"req_key"`
|
|
||||||
// 文生图参数
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Seed int64 `json:"seed,omitempty"`
|
|
||||||
Scale float64 `json:"scale,omitempty"`
|
|
||||||
Width int `json:"width,omitempty"`
|
|
||||||
Height int `json:"height,omitempty"`
|
|
||||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
|
||||||
// 图生图参数
|
|
||||||
ImageInput string `json:"image_input,omitempty"`
|
|
||||||
ImageUrls []string `json:"image_urls,omitempty"`
|
|
||||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
|
||||||
Gpen float64 `json:"gpen,omitempty"`
|
|
||||||
Skin float64 `json:"skin,omitempty"`
|
|
||||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
|
||||||
GenMode string `json:"gen_mode,omitempty"`
|
|
||||||
// 图像编辑参数
|
|
||||||
// 图像特效参数
|
|
||||||
ImageInput1 string `json:"image_input1,omitempty"`
|
|
||||||
TemplateId string `json:"template_id,omitempty"`
|
|
||||||
// 视频生成参数
|
|
||||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SubmitTaskResponse 提交任务响应
|
// SubmitTaskResponse 提交任务响应
|
||||||
type SubmitTaskResponse struct {
|
type SubmitTaskResponse struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
@@ -75,6 +49,8 @@ type QueryTaskResponse struct {
|
|||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CodeSuccess = 10000
|
||||||
|
|
||||||
// CreateTaskRequest 创建任务请求
|
// CreateTaskRequest 创建任务请求
|
||||||
type CreateTaskRequest struct {
|
type CreateTaskRequest struct {
|
||||||
Type types.JMTaskType `json:"type"`
|
Type types.JMTaskType `json:"type"`
|
||||||
@@ -84,3 +60,8 @@ type CreateTaskRequest struct {
|
|||||||
ImageUrls []string `json:"image_urls,omitempty"`
|
ImageUrls []string `json:"image_urls,omitempty"`
|
||||||
Power int `json:"power,omitempty"`
|
Power int `json:"power,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ImageEffectReqKey = "i2i_multi_style_zx2x"
|
||||||
|
DoubaoSeedream40ReqKey = "doubao-seedream-4-0-250828"
|
||||||
|
)
|
||||||
|
|||||||
817
api/utils/media_duration.go
Normal file
817
api/utils/media_duration.go
Normal file
@@ -0,0 +1,817 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AudioDuration returns duration of an audio file.
|
||||||
|
// Supported formats: MP3, WAV (auto-detected by header)
|
||||||
|
func AudioDuration(path string) (time.Duration, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Peek first 12 bytes to detect format
|
||||||
|
head := make([]byte, 12)
|
||||||
|
n, err := io.ReadFull(f, head)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n < 12 {
|
||||||
|
return 0, errors.New("file too small")
|
||||||
|
}
|
||||||
|
|
||||||
|
// WAV: RIFF....WAVE
|
||||||
|
if string(head[0:4]) == "RIFF" && string(head[8:12]) == "WAVE" {
|
||||||
|
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return wavDuration(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MP3 can start with ID3 or frame sync 0xFFEx
|
||||||
|
if string(head[0:3]) == "ID3" || (head[0] == 0xFF && (head[1]&0xE0) == 0xE0) {
|
||||||
|
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return mp3Duration(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, errors.New("unsupported audio format")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioDurationFromURL downloads the url to a temp file and returns duration.
|
||||||
|
func AudioDurationFromURL(url string) (time.Duration, error) {
|
||||||
|
path, err := fetchURLToTemp(url, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer os.Remove(path)
|
||||||
|
return AudioDuration(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoDurationMP4 returns duration of an MP4 file (MOV/MP4 base media).
|
||||||
|
func VideoDurationMP4(path string) (time.Duration, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return mp4Duration(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoDurationMP4FromURL downloads the url to a temp file and returns duration.
|
||||||
|
func VideoDurationMP4FromURL(url string) (time.Duration, error) {
|
||||||
|
path, err := fetchURLToTemp(url, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer os.Remove(path)
|
||||||
|
return VideoDurationMP4(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------- helpers ----------------------
|
||||||
|
|
||||||
|
func fetchURLToTemp(url string, timeout time.Duration) (string, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("http status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp, err := os.CreateTemp("", "media-*")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer tmp.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(tmp, resp.Body); err != nil {
|
||||||
|
path := tmp.Name()
|
||||||
|
_ = os.Remove(path)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return tmp.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------- WAV ----------------------
|
||||||
|
|
||||||
|
func wavDuration(r io.ReadSeeker) (time.Duration, error) {
|
||||||
|
// RIFF header already checked outside if needed. We parse chunks to get fmt and data.
|
||||||
|
// WAV little-endian
|
||||||
|
if _, err := r.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read RIFF header (12 bytes)
|
||||||
|
head := make([]byte, 12)
|
||||||
|
if _, err := io.ReadFull(r, head); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if string(head[0:4]) != "RIFF" || string(head[8:12]) != "WAVE" {
|
||||||
|
return 0, errors.New("invalid wav header")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sampleRate uint32
|
||||||
|
var numChans uint16
|
||||||
|
var bitsPerSample uint16
|
||||||
|
var byteRate uint32
|
||||||
|
var dataSize uint32
|
||||||
|
|
||||||
|
for {
|
||||||
|
chunkHdr := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(r, chunkHdr); err != nil {
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
ckID := string(chunkHdr[0:4])
|
||||||
|
ckSize := binary.LittleEndian.Uint32(chunkHdr[4:8])
|
||||||
|
|
||||||
|
switch ckID {
|
||||||
|
case "fmt ":
|
||||||
|
fmtData := make([]byte, ckSize)
|
||||||
|
if _, err := io.ReadFull(r, fmtData); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// audioFormat := binary.LittleEndian.Uint16(fmtData[0:2]) // 1 = PCM
|
||||||
|
numChans = binary.LittleEndian.Uint16(fmtData[2:4])
|
||||||
|
sampleRate = binary.LittleEndian.Uint32(fmtData[4:8])
|
||||||
|
byteRate = binary.LittleEndian.Uint32(fmtData[8:12])
|
||||||
|
// blockAlign := binary.LittleEndian.Uint16(fmtData[12:14])
|
||||||
|
if len(fmtData) >= 16 {
|
||||||
|
bitsPerSample = binary.LittleEndian.Uint16(fmtData[14:16])
|
||||||
|
}
|
||||||
|
case "data":
|
||||||
|
dataSize = ckSize
|
||||||
|
// Skip data content
|
||||||
|
if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Skip other chunks
|
||||||
|
if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Chunks are word-aligned (pad byte if odd size)
|
||||||
|
if ckSize%2 == 1 {
|
||||||
|
if _, err := r.Seek(1, io.SeekCurrent); err != nil { // skip pad
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sampleRate == 0 || numChans == 0 {
|
||||||
|
return 0, errors.New("invalid wav fmt")
|
||||||
|
}
|
||||||
|
|
||||||
|
var durationSeconds float64
|
||||||
|
if byteRate != 0 {
|
||||||
|
durationSeconds = float64(dataSize) / float64(byteRate)
|
||||||
|
} else {
|
||||||
|
bytesPerSec := float64(sampleRate) * float64(numChans) * float64(bitsPerSample) / 8.0
|
||||||
|
if bytesPerSec == 0 {
|
||||||
|
return 0, errors.New("invalid wav parameters")
|
||||||
|
}
|
||||||
|
durationSeconds = float64(dataSize) / bytesPerSec
|
||||||
|
}
|
||||||
|
return time.Duration(durationSeconds * float64(time.Second)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------- MP3 ----------------------
|
||||||
|
|
||||||
|
func mp3Duration(r io.ReadSeeker) (time.Duration, error) {
|
||||||
|
// Strategy:
|
||||||
|
// 1) Skip ID3v2 header if present.
|
||||||
|
// 2) Try read first frame and detect XING/Info or VBRI to get total frames and duration.
|
||||||
|
// 3) If VBR headers not present, fall back to CBR estimation: (audioDataBytes * 8) / bitrate.
|
||||||
|
|
||||||
|
// File size
|
||||||
|
fi, err := fileSizeFromSeeker(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip ID3v2
|
||||||
|
var id3v2Size int64
|
||||||
|
if _, err := r.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
id3v2Size, err = skipID3v2(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remember audio start offset
|
||||||
|
startOffset, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
|
||||||
|
// Read first frame header (search sync)
|
||||||
|
off, fh, err := findNextMP3Frame(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := r.Seek(off, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for XING/Info header in first frame (for VBR)
|
||||||
|
totalFrames, sampleRate, samplesPerFrame, bitrateKbps, vbrFound, err := parseFirstFrameForVBR(r, fh)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if vbrFound && totalFrames > 0 && sampleRate > 0 && samplesPerFrame > 0 {
|
||||||
|
seconds := (float64(totalFrames) * float64(samplesPerFrame)) / float64(sampleRate)
|
||||||
|
return time.Duration(seconds * float64(time.Second)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to CBR estimate using bitrate and data size (excluding ID3v2 and ID3v1)
|
||||||
|
// Detect ID3v1 at end (128 bytes TAG)
|
||||||
|
var id3v1Size int64
|
||||||
|
if fi >= 128 {
|
||||||
|
if _, err := r.Seek(fi-128, io.SeekStart); err == nil {
|
||||||
|
buf := make([]byte, 3)
|
||||||
|
if _, err := io.ReadFull(r, buf); err == nil {
|
||||||
|
if string(buf) == "TAG" {
|
||||||
|
id3v1Size = 128
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
audioBytes := fi - id3v2Size - id3v1Size - startOffset
|
||||||
|
if audioBytes <= 0 || bitrateKbps == 0 {
|
||||||
|
return 0, errors.New("unable to estimate mp3 duration")
|
||||||
|
}
|
||||||
|
// bitrateKbps in kbps, bytes -> bits
|
||||||
|
seconds := float64(audioBytes*8) / float64(bitrateKbps*1000)
|
||||||
|
return time.Duration(seconds * float64(time.Second)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mp3FrameHeader struct {
|
||||||
|
Version int // 1: MPEG1, 2: MPEG2, 25: MPEG2.5
|
||||||
|
Layer int // 1,2,3
|
||||||
|
BitrateKbps int
|
||||||
|
SampleRate int
|
||||||
|
Padding int
|
||||||
|
ChannelMode int // 0:Stereo,1:Joint,2:Dual,3:Mono
|
||||||
|
}
|
||||||
|
|
||||||
|
func findNextMP3Frame(r io.ReadSeeker) (int64, mp3FrameHeader, error) {
|
||||||
|
var hdr mp3FrameHeader
|
||||||
|
// Start from current pos and scan up to 64KB
|
||||||
|
start, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
limit := int64(64 * 1024)
|
||||||
|
buf := make([]byte, limit)
|
||||||
|
n, err := r.Read(buf)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return 0, hdr, err
|
||||||
|
}
|
||||||
|
for i := 0; i+4 <= n; i++ {
|
||||||
|
if buf[i] == 0xFF && (buf[i+1]&0xE0) == 0xE0 { // sync
|
||||||
|
if h, ok := parseMP3Header(buf[i : i+4]); ok {
|
||||||
|
offset := start + int64(i)
|
||||||
|
return offset, h, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, hdr, errors.New("mp3 frame not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMP3Header(b []byte) (mp3FrameHeader, bool) {
|
||||||
|
var h mp3FrameHeader
|
||||||
|
if len(b) < 4 {
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
if b[0] != 0xFF || (b[1]&0xE0) != 0xE0 {
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
versionBits := (b[1] >> 3) & 0x03
|
||||||
|
layerBits := (b[1] >> 1) & 0x03
|
||||||
|
bitrateBits := (b[2] >> 4) & 0x0F
|
||||||
|
sampleRateBits := (b[2] >> 2) & 0x03
|
||||||
|
paddingBit := (b[2] >> 1) & 0x01
|
||||||
|
channelMode := (b[3] >> 6) & 0x03
|
||||||
|
|
||||||
|
var version int
|
||||||
|
switch versionBits {
|
||||||
|
case 0x00:
|
||||||
|
version = 25 // MPEG 2.5
|
||||||
|
case 0x02:
|
||||||
|
version = 2 // MPEG 2
|
||||||
|
case 0x03:
|
||||||
|
version = 1 // MPEG 1
|
||||||
|
default:
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var layer int
|
||||||
|
switch layerBits {
|
||||||
|
case 0x01:
|
||||||
|
layer = 3
|
||||||
|
case 0x02:
|
||||||
|
layer = 2
|
||||||
|
case 0x03:
|
||||||
|
layer = 1
|
||||||
|
default:
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
|
||||||
|
br := mp3BitrateKbps(version, layer, int(bitrateBits))
|
||||||
|
if br == 0 {
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
sr := mp3SampleRate(version, int(sampleRateBits))
|
||||||
|
if sr == 0 {
|
||||||
|
return h, false
|
||||||
|
}
|
||||||
|
|
||||||
|
h = mp3FrameHeader{
|
||||||
|
Version: version,
|
||||||
|
Layer: layer,
|
||||||
|
BitrateKbps: br,
|
||||||
|
SampleRate: sr,
|
||||||
|
Padding: int(paddingBit),
|
||||||
|
ChannelMode: int(channelMode),
|
||||||
|
}
|
||||||
|
return h, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp3BitrateKbps(version, layer, index int) int {
|
||||||
|
// index: 1..14 valid; 0,15 invalid
|
||||||
|
if index <= 0 || index == 15 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// Tables per ISO/IEC 11172-3/13818-3 (common subset)
|
||||||
|
var tbl [15]int
|
||||||
|
if layer == 1 { // Layer I
|
||||||
|
if version == 1 { // MPEG1
|
||||||
|
tbl = [15]int{0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448}
|
||||||
|
} else { // MPEG2/2.5
|
||||||
|
tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256}
|
||||||
|
}
|
||||||
|
} else if layer == 2 { // Layer II
|
||||||
|
if version == 1 {
|
||||||
|
tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384}
|
||||||
|
} else {
|
||||||
|
tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160}
|
||||||
|
}
|
||||||
|
} else { // Layer III
|
||||||
|
if version == 1 {
|
||||||
|
tbl = [15]int{0, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320}
|
||||||
|
} else {
|
||||||
|
tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tbl[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp3SampleRate(version, index int) int {
|
||||||
|
if index == 3 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// base table for MPEG1
|
||||||
|
base := [3]int{44100, 48000, 32000}
|
||||||
|
sr := base[index]
|
||||||
|
if version == 2 { // MPEG2
|
||||||
|
sr /= 2
|
||||||
|
} else if version == 25 { // MPEG2.5
|
||||||
|
sr /= 4
|
||||||
|
}
|
||||||
|
return sr
|
||||||
|
}
|
||||||
|
|
||||||
|
func samplesPerMP3Frame(version, layer int) int {
|
||||||
|
switch layer {
|
||||||
|
case 1:
|
||||||
|
return 384
|
||||||
|
case 2:
|
||||||
|
return 1152
|
||||||
|
case 3:
|
||||||
|
if version == 1 {
|
||||||
|
return 1152
|
||||||
|
}
|
||||||
|
return 576 // MPEG2/2.5 Layer III
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFirstFrameForVBR(r io.ReadSeeker, fh mp3FrameHeader) (totalFrames uint32, sampleRate int, samplesPerFrame int, bitrateKbps int, vbrFound bool, err error) {
|
||||||
|
// After the 4-byte header, possible side info and then XING/Info
|
||||||
|
if _, err = r.Seek(0, io.SeekCurrent); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Re-read header
|
||||||
|
hdr := make([]byte, 4)
|
||||||
|
if _, err = io.ReadFull(r, hdr); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// side info size depends on MPEG version and channel mode (for Layer III)
|
||||||
|
sideInfoSize := 0
|
||||||
|
if fh.Layer == 3 { // Layer III
|
||||||
|
if fh.Version == 1 { // MPEG1
|
||||||
|
if fh.ChannelMode == 3 { // mono
|
||||||
|
sideInfoSize = 17
|
||||||
|
} else {
|
||||||
|
sideInfoSize = 32
|
||||||
|
}
|
||||||
|
} else { // MPEG2/2.5
|
||||||
|
if fh.ChannelMode == 3 {
|
||||||
|
sideInfoSize = 9
|
||||||
|
} else {
|
||||||
|
sideInfoSize = 17
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read next up to 120 bytes to search for XING/Info or VBRI
|
||||||
|
buf := make([]byte, sideInfoSize+120)
|
||||||
|
if _, err = io.ReadFull(r, buf); err != nil {
|
||||||
|
// If short, still try within available
|
||||||
|
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search XING/Info signature
|
||||||
|
sigs := [][]byte{[]byte("Xing"), []byte("Info")}
|
||||||
|
for _, sig := range sigs {
|
||||||
|
idx := indexOf(buf, sig)
|
||||||
|
if idx >= 0 {
|
||||||
|
// flags after signature (4 bytes), then if frames flag set, 4 bytes frames
|
||||||
|
if len(buf) >= idx+4+4 {
|
||||||
|
flags := binary.BigEndian.Uint32(buf[idx+4 : idx+8])
|
||||||
|
var frames uint32
|
||||||
|
if (flags & 0x01) != 0 { // frames present
|
||||||
|
if len(buf) >= idx+8+4 {
|
||||||
|
frames = binary.BigEndian.Uint32(buf[idx+8 : idx+12])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if frames > 0 {
|
||||||
|
vbrFound = true
|
||||||
|
totalFrames = frames
|
||||||
|
sampleRate = fh.SampleRate
|
||||||
|
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
|
||||||
|
bitrateKbps = fh.BitrateKbps
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check VBRI (usually at 32 bytes after header for MPEG1 Layer III)
|
||||||
|
if len(buf) >= 4 {
|
||||||
|
idx := indexOf(buf, []byte("VBRI"))
|
||||||
|
if idx >= 0 {
|
||||||
|
if len(buf) >= idx+4+2+2+4+4 {
|
||||||
|
// VBRI layout: 'VBRI'(4) + version(2) + delay(2) + quality(2?) varies; but at offset 10 comes bytes: bytes (4), frames (4)
|
||||||
|
// Some docs: offset 10: bytes, offset 14: frames (big-endian)
|
||||||
|
bytesOffset := idx + 10
|
||||||
|
framesOffset := idx + 14
|
||||||
|
if len(buf) >= framesOffset+4 {
|
||||||
|
frames := binary.BigEndian.Uint32(buf[framesOffset : framesOffset+4])
|
||||||
|
if frames > 0 {
|
||||||
|
vbrFound = true
|
||||||
|
totalFrames = frames
|
||||||
|
sampleRate = fh.SampleRate
|
||||||
|
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
|
||||||
|
bitrateKbps = fh.BitrateKbps
|
||||||
|
_ = bytesOffset // not used
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No VBR header. Provide header info for CBR fallback
|
||||||
|
sampleRate = fh.SampleRate
|
||||||
|
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
|
||||||
|
bitrateKbps = fh.BitrateKbps
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexOf(haystack []byte, needle []byte) int {
|
||||||
|
for i := 0; i+len(needle) <= len(haystack); i++ {
|
||||||
|
match := true
|
||||||
|
for j := 0; j < len(needle); j++ {
|
||||||
|
if haystack[i+j] != needle[j] {
|
||||||
|
match = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipID3v2(r io.ReadSeeker) (int64, error) {
|
||||||
|
if _, err := r.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
head := make([]byte, 10)
|
||||||
|
if _, err := io.ReadFull(r, head); err != nil {
|
||||||
|
return 0, nil // no header
|
||||||
|
}
|
||||||
|
if string(head[0:3]) != "ID3" {
|
||||||
|
if _, err := r.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
// size: 4 synchsafe bytes
|
||||||
|
sz := int64((int(head[6]&0x7F) << 21) | (int(head[7]&0x7F) << 14) | (int(head[8]&0x7F) << 7) | int(head[9]&0x7F))
|
||||||
|
// total header size = 10 + sz (+ footer 10 if flag set)
|
||||||
|
footer := int64(0)
|
||||||
|
if (head[5] & 0x10) != 0 { // footer present
|
||||||
|
footer = 10
|
||||||
|
}
|
||||||
|
total := 10 + sz + footer
|
||||||
|
if _, err := r.Seek(total, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileSizeFromSeeker(r io.ReadSeeker) (int64, error) {
|
||||||
|
cur, err := r.Seek(0, io.SeekCurrent)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
end, err := r.Seek(0, io.SeekEnd)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := r.Seek(cur, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------- MP4 ----------------------
|
||||||
|
|
||||||
|
type mp4BoxHeader struct {
|
||||||
|
Size uint64
|
||||||
|
Type [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBoxHeader(r io.ReadSeeker) (mp4BoxHeader, error) {
|
||||||
|
var h mp4BoxHeader
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return h, err
|
||||||
|
}
|
||||||
|
sz := binary.BigEndian.Uint32(buf[0:4])
|
||||||
|
copy(h.Type[:], buf[4:8])
|
||||||
|
if sz == 1 {
|
||||||
|
// 64-bit size follows
|
||||||
|
ext := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(r, ext); err != nil {
|
||||||
|
return h, err
|
||||||
|
}
|
||||||
|
h.Size = binary.BigEndian.Uint64(ext)
|
||||||
|
} else {
|
||||||
|
h.Size = uint64(sz)
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipBox(r io.ReadSeeker, boxSize uint64, alreadyRead int64) error {
|
||||||
|
toSkip := int64(boxSize) - alreadyRead
|
||||||
|
if toSkip < 0 {
|
||||||
|
return fmt.Errorf("invalid box size")
|
||||||
|
}
|
||||||
|
_, err := r.Seek(toSkip, io.SeekCurrent)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp4Duration(r io.ReadSeeker) (time.Duration, error) {
|
||||||
|
if _, err := r.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Find moov box
|
||||||
|
var moovStart int64
|
||||||
|
var moovSize uint64
|
||||||
|
for {
|
||||||
|
pos, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
h, err := readBoxHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if string(h.Type[:]) == "moov" {
|
||||||
|
moovStart = pos
|
||||||
|
moovSize = h.Size
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if h.Size < 8 {
|
||||||
|
return 0, errors.New("invalid mp4 box size")
|
||||||
|
}
|
||||||
|
if err := skipBox(r, h.Size, 8); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if moovStart == 0 && moovSize == 0 {
|
||||||
|
return 0, errors.New("moov not found")
|
||||||
|
}
|
||||||
|
// Parse inside moov for video trak mdhd, else mvhd
|
||||||
|
if _, err := r.Seek(moovStart+8, io.SeekStart); err != nil { // skip moov header
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
end := moovStart + int64(moovSize)
|
||||||
|
var movieTimescale uint32
|
||||||
|
var movieDuration uint64
|
||||||
|
var foundVideoMdhd bool
|
||||||
|
var mdhdTimescale uint32
|
||||||
|
var mdhdDuration uint64
|
||||||
|
|
||||||
|
for {
|
||||||
|
pos, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
if pos >= end {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
h, err := readBoxHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
switch string(h.Type[:]) {
|
||||||
|
case "mvhd":
|
||||||
|
// movie header
|
||||||
|
ver := make([]byte, 1)
|
||||||
|
if _, err := io.ReadFull(r, ver); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if ver[0] == 1 {
|
||||||
|
// version 1: 64-bit duration
|
||||||
|
buf := make([]byte, 8+8+4+8) // ctime(8) mtime(8) timescale(4) duration(8)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
movieTimescale = binary.BigEndian.Uint32(buf[16:20])
|
||||||
|
movieDuration = binary.BigEndian.Uint64(buf[20:28])
|
||||||
|
} else {
|
||||||
|
buf := make([]byte, 4+4+4+4) // ctime mtime timescale duration
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
movieTimescale = binary.BigEndian.Uint32(buf[8:12])
|
||||||
|
movieDuration = uint64(binary.BigEndian.Uint32(buf[12:16]))
|
||||||
|
}
|
||||||
|
// skip rest of mvhd
|
||||||
|
read := int64(1 + 3)
|
||||||
|
if ver[0] == 1 {
|
||||||
|
read += int64(8 + 8 + 4 + 8)
|
||||||
|
} else {
|
||||||
|
read += int64(4 + 4 + 4 + 4)
|
||||||
|
}
|
||||||
|
if err := skipBox(r, h.Size, 8+read); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
case "trak":
|
||||||
|
// parse trak for hdlr and mdhd
|
||||||
|
tEnd := int64(0)
|
||||||
|
if h.Size < 8 {
|
||||||
|
return 0, errors.New("invalid trak size")
|
||||||
|
}
|
||||||
|
tEnd = pos + int64(h.Size)
|
||||||
|
var isVideo bool
|
||||||
|
var tMdhdTimescale uint32
|
||||||
|
var tMdhdDuration uint64
|
||||||
|
for {
|
||||||
|
cpos, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
if cpos >= tEnd {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ch, err := readBoxHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
switch string(ch.Type[:]) {
|
||||||
|
case "mdia":
|
||||||
|
mEnd := cpos + int64(ch.Size)
|
||||||
|
for {
|
||||||
|
mpos, _ := r.Seek(0, io.SeekCurrent)
|
||||||
|
if mpos >= mEnd {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
mh, err := readBoxHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
switch string(mh.Type[:]) {
|
||||||
|
case "hdlr":
|
||||||
|
// skip version+flags (4), pre_defined(4)
|
||||||
|
if _, err := r.Seek(8, io.SeekCurrent); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
handler := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(r, handler); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if string(handler) == "vide" {
|
||||||
|
isVideo = true
|
||||||
|
}
|
||||||
|
if err := skipBox(r, mh.Size, 8+8+4); err != nil { // header + skipped + read handler
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
case "mdhd":
|
||||||
|
ver := make([]byte, 1)
|
||||||
|
if _, err := io.ReadFull(r, ver); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if ver[0] == 1 {
|
||||||
|
buf := make([]byte, 8+8+4+8)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
tMdhdTimescale = binary.BigEndian.Uint32(buf[16:20])
|
||||||
|
tMdhdDuration = binary.BigEndian.Uint64(buf[20:28])
|
||||||
|
} else {
|
||||||
|
buf := make([]byte, 4+4+4+4)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
tMdhdTimescale = binary.BigEndian.Uint32(buf[8:12])
|
||||||
|
tMdhdDuration = uint64(binary.BigEndian.Uint32(buf[12:16]))
|
||||||
|
}
|
||||||
|
if err := skipBox(r, mh.Size, 8+1+3+int64(lenVersionPayload(ver[0]))); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := skipBox(r, mh.Size, 8); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := skipBox(r, ch.Size, 8); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isVideo && tMdhdTimescale != 0 && tMdhdDuration != 0 {
|
||||||
|
foundVideoMdhd = true
|
||||||
|
mdhdTimescale = tMdhdTimescale
|
||||||
|
mdhdDuration = tMdhdDuration
|
||||||
|
}
|
||||||
|
// Skip remaining of trak if any
|
||||||
|
if _, err := r.Seek(tEnd, io.SeekStart); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := skipBox(r, h.Size, 8); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundVideoMdhd && mdhdTimescale != 0 {
|
||||||
|
sec := float64(mdhdDuration) / float64(mdhdTimescale)
|
||||||
|
return time.Duration(sec * float64(time.Second)), nil
|
||||||
|
}
|
||||||
|
if movieTimescale != 0 {
|
||||||
|
sec := float64(movieDuration) / float64(movieTimescale)
|
||||||
|
return time.Duration(sec * float64(time.Second)), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("failed to read mp4 duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
func lenVersionPayload(ver byte) int {
|
||||||
|
if ver == 1 {
|
||||||
|
return 8 + 8 + 4 + 8
|
||||||
|
}
|
||||||
|
return 4 + 4 + 4 + 4
|
||||||
|
}
|
||||||
@@ -225,7 +225,7 @@ const initModelValue = (model) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
defaultValues.model = selectedModel.value.key
|
defaultValues.req_key = selectedModel.value.key
|
||||||
return defaultValues
|
return defaultValues
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ const props = defineProps({
|
|||||||
|
|
||||||
const typeClass = computed(() => {
|
const typeClass = computed(() => {
|
||||||
return {
|
return {
|
||||||
info: 'bg-blue-100 text-blue-500 border-blue-500',
|
info: 'bg-blue-100 text-white-500 border-blue-500',
|
||||||
success: 'bg-green-100 text-green-500 border-green-500',
|
success: 'bg-green-100 text-green-500 border-green-500',
|
||||||
warning: 'bg-yellow-100 text-yellow-500 border-yellow-500',
|
warning: 'bg-yellow-100 text-yellow-500 border-yellow-500',
|
||||||
}[props.type]
|
}[props.type]
|
||||||
|
|||||||
@@ -537,6 +537,22 @@ export const JimengParams = {
|
|||||||
label: '21:9 (2016 * 864)',
|
label: '21:9 (2016 * 864)',
|
||||||
value: '2016x864',
|
value: '2016x864',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: '9:16 (936 * 1664)',
|
||||||
|
value: '936x1664',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '2:3 (1056 * 1584)',
|
||||||
|
value: '1056x1584',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '3:4 (1104 * 1472)',
|
||||||
|
value: '1104x1472',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: '9:21 (864 * 2016)',
|
||||||
|
value: '864x2016',
|
||||||
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -195,14 +195,19 @@ export const useJimengStore = defineStore('jimeng', () => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
|
formData.value.type = activeFunction.value
|
||||||
const response = await httpPost('/api/jimeng/task', formData.value)
|
// 视频 duration 转成整数
|
||||||
if (response.data) {
|
if (formData.value.duration) {
|
||||||
showMessageOK('任务提交成功')
|
formData.value.duration = parseInt(formData.value.duration)
|
||||||
isOver.value = false
|
|
||||||
await fetchData(1)
|
|
||||||
startPolling()
|
|
||||||
}
|
}
|
||||||
|
if (formData.value.image_urls && !Array.isArray(formData.value.image_urls)) {
|
||||||
|
formData.value.image_urls = [formData.value.image_urls]
|
||||||
|
}
|
||||||
|
const response = await httpPost('/api/jimeng/task', formData.value)
|
||||||
|
showMessageOK('任务提交成功')
|
||||||
|
isOver.value = false
|
||||||
|
await fetchData(1)
|
||||||
|
startPolling()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('提交任务失败:', error)
|
console.error('提交任务失败:', error)
|
||||||
showMessageError(error.message || '提交任务失败')
|
showMessageError(error.message || '提交任务失败')
|
||||||
|
|||||||
@@ -24,10 +24,15 @@
|
|||||||
>
|
>
|
||||||
和
|
和
|
||||||
<a
|
<a
|
||||||
href="https://console.volcengine.com/ai/ability/detail/9"
|
href="https://console.volcengine.com/ai/ability/detail/1"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
class="text-blue-500"
|
class="text-blue-500"
|
||||||
>智能绘图</a
|
>智能绘图</a
|
||||||
|
>以及<a
|
||||||
|
href="https://console.volcengine.com/ark/region:ark+cn-beijing/openManagement"
|
||||||
|
target="_blank"
|
||||||
|
class="text-blue-500"
|
||||||
|
>火山方舟</a
|
||||||
>
|
>
|
||||||
服务。
|
服务。
|
||||||
</p>
|
</p>
|
||||||
@@ -41,6 +46,17 @@
|
|||||||
>
|
>
|
||||||
获取。
|
获取。
|
||||||
</p>
|
</p>
|
||||||
|
<p>
|
||||||
|
3. ApiKey 请在火山方舟控制台 ->
|
||||||
|
<a
|
||||||
|
href="https://console.volcengine.com/ark/region:ark+cn-beijing/apiKey?apikey=%7B%7D"
|
||||||
|
target="_blank"
|
||||||
|
class="text-blue-500"
|
||||||
|
>
|
||||||
|
API Key管理</a
|
||||||
|
>
|
||||||
|
获取。
|
||||||
|
</p>
|
||||||
</Alert>
|
</Alert>
|
||||||
</div>
|
</div>
|
||||||
<el-form-item label="AccessKey" prop="access_key">
|
<el-form-item label="AccessKey" prop="access_key">
|
||||||
@@ -49,6 +65,19 @@
|
|||||||
<el-form-item label="SecretKey" prop="secret_key">
|
<el-form-item label="SecretKey" prop="secret_key">
|
||||||
<el-input v-model="jimengConfig.secret_key" placeholder="请输入即梦AI的SecretKey" />
|
<el-input v-model="jimengConfig.secret_key" placeholder="请输入即梦AI的SecretKey" />
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
<el-form-item prop="api_key">
|
||||||
|
<template #label>
|
||||||
|
<div class="text-sm">
|
||||||
|
火山方舟服务API Key(<span class="text-red-400"
|
||||||
|
>目前火山方舟服务只支持API Key验证</span
|
||||||
|
>)
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<el-input v-model="jimengConfig.api_key" placeholder="请输入火山方舟服务API Key" />
|
||||||
|
<div class="text-sm mt-2 text-gray-500">
|
||||||
|
目前豆包生图 4.0 模型在即梦API中不支持,需要使用火山方舟服务。
|
||||||
|
</div>
|
||||||
|
</el-form-item>
|
||||||
</div>
|
</div>
|
||||||
<el-divider />
|
<el-divider />
|
||||||
<!-- 算力配置分组 -->
|
<!-- 算力配置分组 -->
|
||||||
|
|||||||
Reference in New Issue
Block a user