Merge tag 'v4.2.7'

This commit is contained in:
RockYang
2026-05-09 20:56:07 +08:00
124 changed files with 6679 additions and 4580 deletions

View File

@@ -89,6 +89,7 @@ type BaseConfig struct {
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
IndexPage string `json:"index_page"` // 首页显示的页面
Copyright string `json:"copyright"` // 版权信息
ICP string `json:"icp"` // ICP 备案号
GaBeian string `json:"ga_beian"` // 公安备案号

View File

@@ -20,14 +20,14 @@ func init() {
// CaptchaConfig 行为验证码配置
type CaptchaConfig struct {
ApiKey string `json:"api_key"`
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
Enabled bool `json:"enabled"`
ApiKey string `json:"api_key,omitempty"`
Type string `json:"type,omitempty"` // 验证码类型, 可选值: "dot" 或 "slide"
Enabled bool `json:"enabled,omitempty"`
}
// WxLoginConfig 微信登录配置
type WxLoginConfig struct {
ApiKey string `json:"api_key"`
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
Enabled bool `json:"enabled"` // 是否启用微信登录
ApiKey string `json:"api_key,omitempty"`
NotifyURL string `json:"notify_url,omitempty"` // 登录成功回调 URL
Enabled bool `json:"enabled,omitempty"` // 是否启用微信登录
}

View File

@@ -2,17 +2,64 @@ package types
// JimengConfig 即梦AI配置
type JimengConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Power JimengPower `json:"power"`
// 即梦AI的AccessKey和SecretKey
AccessKey string `json:"access_key,omitempty"`
SecretKey string `json:"secret_key,omitempty"`
// 火山引擎大模型专用的验证方式
ApiKey string `json:"api_key,omitempty"`
// 算力配置
Powers map[string]int `json:"powers,omitempty"`
}
// JimengPower 即梦AI算力配置
type JimengPower struct {
TextToImage int `json:"text_to_image"`
ImageToImage int `json:"image_to_image"`
ImageEdit int `json:"image_edit"`
ImageEffects int `json:"image_effects"`
TextToVideo int `json:"text_to_video"`
ImageToVideo int `json:"image_to_video"`
// JMTaskStatus 任务状态
type JMTaskStatus string
const (
JMTaskStatusInQueue = JMTaskStatus("in_queue") // 任务已提交
JMTaskStatusGenerating = JMTaskStatus("generating") // 任务处理中
JMTaskStatusDone = JMTaskStatus("done") // 处理完成
JMTaskStatusNotFound = JMTaskStatus("not_found") // 任务未找到
JMTaskStatusSuccess = JMTaskStatus("success") // 任务成功
JMTaskStatusFailed = JMTaskStatus("failed") // 任务失败
JMTaskStatusExpired = JMTaskStatus("expired") // 任务过期
)
// JMTaskType 任务类型
type JMTaskType string
const (
JMTaskTypeImage = JMTaskType("image") // 文生图
JMTaskTypeVideo = JMTaskType("video") // 图生图
JMTaskTypeVirtualHuman = JMTaskType("virtual_human") // 图像编辑
JMTaskTypeActionTransfer = JMTaskType("action_transfer") // 图像特效
)
// JimengTaskRequest 即梦AI任务请求
type JimengTaskRequest struct {
TaskType JMTaskType `json:"type"` // 任务类型
ReqKey string `json:"req_key"` // 请求Key
Action string `json:"action"` // 请求Action
Power int `json:"power"` // 消耗算力
// 公共参数
Prompt string `json:"prompt,omitempty"`
ImageUrls []string `json:"image_urls,omitempty"`
// 图片生成参数
Size string `json:"size,omitempty"`
UsePreLLM bool `json:"use_pre_llm,omitempty"`
Scale float64 `json:"scale,omitempty"`
ForceSingle bool `json:"force_single,omitempty"`
// 视频生成参数
Duration int `json:"duration,omitempty"` // 视频时长,单位:秒
TemplateId string `json:"template_id,omitempty"` // 运镜模板ID
AspectRatio string `json:"aspect_ratio,omitempty"`
CameraStrength string `json:"camera_strength,omitempty"` // 运镜强度
// 数字人视频生成参数
AudioURL string `json:"audio_url,omitempty"` // 音频URL
RecognizeKey string `json:"recognize_key,omitempty"` // 识别主体请求Key
// 视频动作迁移参数
VideoURL string `json:"video_url,omitempty"` // 动作视频URL
}

View File

@@ -9,13 +9,13 @@ package types
// 文本审查
type ModerationConfig struct {
Enable bool `json:"enable"` // 是否启用文本审查
Active string `json:"active"`
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
Gitee ModerationGiteeConfig `json:"gitee"`
Baidu ModerationBaiduConfig `json:"baidu"`
Tencent ModerationTencentConfig `json:"tencent"`
Enable bool `json:"enable,omitempty"` // 是否启用文本审查
Active string `json:"active,omitempty"`
EnableGuide bool `json:"enable_guide,omitempty"` // 是否启用模型引导提示词
GuidePrompt string `json:"guide_prompt,omitempty"` // 模型引导提示词
Gitee ModerationGiteeConfig `json:"gitee,omitempty"`
Baidu ModerationBaiduConfig `json:"baidu,omitempty"`
Tencent ModerationTencentConfig `json:"tencent,omitempty"`
}
const (
@@ -26,26 +26,26 @@ const (
// GiteeAI 文本审查配置
type ModerationGiteeConfig struct {
ApiKey string `json:"api_key"`
Model string `json:"model"` // 文本审核模型
ApiKey string `json:"api_key,omitempty"`
Model string `json:"model,omitempty"` // 文本审核模型
}
// 百度文本审查配置
type ModerationBaiduConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
AccessKey string `json:"access_key,omitempty"`
SecretKey string `json:"secret_key,omitempty"`
}
// 腾讯云文本审查配置
type ModerationTencentConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
AccessKey string `json:"access_key,omitempty"`
SecretKey string `json:"secret_key,omitempty"`
}
type ModerationResult struct {
Flagged bool `json:"flagged"`
Categories map[string]bool `json:"categories"`
CategoryScores map[string]float64 `json:"category_scores"`
Flagged bool `json:"flagged,omitempty"`
Categories map[string]bool `json:"categories,omitempty"`
CategoryScores map[string]float64 `json:"category_scores,omitempty"`
}
var ModerationCategories = map[string]string{

View File

@@ -8,39 +8,39 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OSSConfig struct {
Active string `json:"active"`
Local LocalStorageConfig `json:"local"`
Minio MiniOssConfig `json:"minio"`
QiNiu QiNiuOssConfig `json:"qiniu"`
AliYun AliYunOssConfig `json:"aliyun"`
Active string `json:"active,omitempty"`
Local LocalStorageConfig `json:"local,omitempty"`
Minio MiniOssConfig `json:"minio,omitempty"`
QiNiu QiNiuOssConfig `json:"qiniu,omitempty"`
AliYun AliYunOssConfig `json:"aliyun,omitempty"`
}
type MiniOssConfig struct {
Endpoint string `json:"endpoint"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
UseSSL bool `json:"use_ssl"`
Domain string `json:"domain"`
Endpoint string `json:"endpoint,omitempty"`
AccessKey string `json:"access_key,omitempty"`
AccessSecret string `json:"access_secret,omitempty"`
Bucket string `json:"bucket,omitempty"`
UseSSL bool `json:"use_ssl,omitempty"`
Domain string `json:"domain,omitempty"`
}
type QiNiuOssConfig struct {
Zone string `json:"zone"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
Domain string `json:"domain"`
Zone string `json:"zone,omitempty"`
AccessKey string `json:"access_key,omitempty"`
AccessSecret string `json:"access_secret,omitempty"`
Bucket string `json:"bucket,omitempty"`
Domain string `json:"domain,omitempty"`
}
type AliYunOssConfig struct {
Endpoint string `json:"endpoint"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
Domain string `json:"domain"`
Endpoint string `json:"endpoint,omitempty"`
AccessKey string `json:"access_key,omitempty"`
AccessSecret string `json:"access_secret,omitempty"`
Bucket string `json:"bucket,omitempty"`
Domain string `json:"domain,omitempty"`
}
type LocalStorageConfig struct {
BasePath string `json:"base_path"`
BaseURL string `json:"base_url"`
BasePath string `json:"base_path,omitempty"`
BaseURL string `json:"base_url,omitempty"`
}

View File

@@ -1,19 +1,19 @@
package types
type PaymentConfig struct {
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
Epay EpayConfig `json:"epay"` // 易支付配置
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
Alipay AlipayConfig `json:"alipay,omitempty"` // 支付宝支付渠道配置
Epay EpayConfig `json:"epay,omitempty"` // 易支付配置
WxPay WxPayConfig `json:"wxpay,omitempty"` // 微信支付渠道配置
}
// AlipayConfig 支付宝支付配置
type AlipayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
SandBox bool `json:"sandbox"` // 是否沙盒环境
AppId string `json:"app_id"` // 应用 ID
PrivateKey string `json:"private_key"` // 应用私钥
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
Domain string `json:"domain"` // 支付回调域名
Enabled bool `json:"enabled,omitempty"` // 是否启用该支付通道
SandBox bool `json:"sandbox,omitempty"` // 是否沙盒环境
AppId string `json:"app_id,omitempty"` // 应用 ID
PrivateKey string `json:"private_key,omitempty"` // 应用私钥
AlipayPublicKey string `json:"alipay_public_key,omitempty"` // 支付宝公钥
Domain string `json:"domain,omitempty"` // 支付回调域名
}
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
@@ -25,13 +25,13 @@ func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
// WxPayConfig 微信支付配置
type WxPayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 公众号的APPID,如wxd678efh567hg6787
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
PrivateKey string `json:"private_key"` // 商户证书私钥
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
Domain string `json:"domain"` // 支付回调域名
Enabled bool `json:"enabled,omitempty"` // 是否启用该支付通道
AppId string `json:"app_id,omitempty"` // 公众号的APPID,如wxd678efh567hg6787
MchId string `json:"mch_id,omitempty"` // 直连商户的商户号,由微信支付生成并下发
SerialNo string `json:"serial_no,omitempty"` // 商户证书的证书序列号
PrivateKey string `json:"private_key,omitempty"` // 商户证书私钥
ApiV3Key string `json:"api_v3_key,omitempty"` // API V3 秘钥
Domain string `json:"domain,omitempty"` // 支付回调域名
}
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
@@ -45,11 +45,11 @@ func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
// EpayConfig 易支付配置
type EpayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 商户 ID
PrivateKey string `json:"private_key"` // 私钥
ApiURL string `json:"api_url"` // z支付 API 网关
Domain string `json:"domain"` // 支付回调域名
Enabled bool `json:"enabled,omitempty"` // 是否启用该支付通道
AppId string `json:"app_id,omitempty"` // 商户 ID
PrivateKey string `json:"private_key,omitempty"` // 私钥
ApiURL string `json:"api_url,omitempty"` // z支付 API 网关
Domain string `json:"domain,omitempty"` // 支付回调域名
}
func (c *EpayConfig) Equal(other *EpayConfig) bool {

View File

@@ -8,23 +8,23 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SMSConfig struct {
Active string `json:"active"`
Ali SmsConfigAli `json:"aliyun"`
Bao SmsConfigBao `json:"bao"`
Active string `json:"active,omitempty"`
Ali SmsConfigAli `json:"aliyun,omitempty"`
Bao SmsConfigBao `json:"bao,omitempty"`
}
// SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct {
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Sign string `json:"sign"` // 短信签名
CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
AccessKey string `json:"access_key,omitempty"`
AccessSecret string `json:"access_secret,omitempty"`
Sign string `json:"sign,omitempty"` // 短信签名
CodeTempId string `json:"code_temp_id,omitempty"` // 验证码短信模板 ID
}
// SmsConfigBao 短信宝平台配置
type SmsConfigBao struct {
Username string `json:"username"` //短信宝平台注册的用户名
Password string `json:"password"` //短信宝平台注册的密码
Sign string `json:"sign"` // 短信签名
CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
Username string `json:"username,omitempty"` //短信宝平台注册的用户名
Password string `json:"password,omitempty"` //短信宝平台注册的密码
Sign string `json:"sign,omitempty"` // 短信签名
CodeTemplate string `json:"code_template,omitempty"` // 验证码短信模板 匹配
}

View File

@@ -8,12 +8,12 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SmtpConfig struct {
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
Host string `json:"host"` // 邮件服务器地址
Port int `json:"port"` // 邮件服务器端口
AppName string `json:"app_name"` // 应用名称
From string `json:"from"` // 发件人邮箱地址
Password string `json:"password"` // 发件人邮箱密码
UseTls bool `json:"use_tls,omitempty"` // 是否使用 TLS 发送
Host string `json:"host,omitempty"` // 邮件服务器地址
Port int `json:"port,omitempty"` // 邮件服务器端口
AppName string `json:"app_name,omitempty"` // 应用名称
From string `json:"from,omitempty"` // 发件人邮箱地址
Password string `json:"password,omitempty"` // 发件人邮箱密码
}
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {

View File

@@ -0,0 +1,45 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "sync"
// UserLockManager 提供基于用户ID的TryLock功能确保同一用户并发请求串行化
type UserLockManager struct {
mu sync.Mutex
locks map[uint]bool
}
func NewUserLockManager() *UserLockManager {
return &UserLockManager{mu: sync.Mutex{}, locks: make(map[uint]bool)}
}
// TryLock 尝试为指定用户加锁。若已被占用返回 false
func (m *UserLockManager) TryLock(userId uint) bool {
if userId == 0 {
return true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.locks[userId] {
return false
}
m.locks[userId] = true
return true
}
// Unlock 释放指定用户的锁
func (m *UserLockManager) Unlock(userId uint) {
if userId == 0 {
return
}
m.mu.Lock()
delete(m.locks, userId)
m.mu.Unlock()
}

View File

@@ -33,6 +33,7 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
github.com/volcengine/volcengine-go-sdk v1.1.34
golang.org/x/image v0.15.0
)
@@ -50,6 +51,7 @@ require (
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
require (

View File

@@ -100,6 +100,7 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
@@ -110,6 +111,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
@@ -259,6 +261,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
github.com/volcengine/volcengine-go-sdk v1.1.34 h1:ha90JycCCTJNCse0UDziBgBsuX98ITOrkwYlDWcm7NI=
github.com/volcengine/volcengine-go-sdk v1.1.34/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
@@ -390,6 +394,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -31,7 +31,7 @@ import (
var logger = logger2.GetLogger()
const SuperManagerID = 1
const SuperUsername = "admin"
type ManagerHandler struct {
handler.BaseHandler
@@ -94,7 +94,7 @@ func (h *ManagerHandler) Login(c *gin.Context) {
}
// 超级管理员默认是ID:1
if manager.Id != SuperManagerID && manager.Status == false {
if manager.Username != SuperUsername && !manager.Status {
resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
return
}
@@ -125,7 +125,7 @@ func (h *ManagerHandler) Login(c *gin.Context) {
IsSuperAdmin bool `json:"is_super_admin"`
Token string `json:"token"`
}{
IsSuperAdmin: manager.Id == 1,
IsSuperAdmin: manager.Username == SuperUsername,
Token: tokenString,
}
@@ -227,12 +227,19 @@ func (h *ManagerHandler) Remove(c *gin.Context) {
return
}
if id == SuperManagerID {
var user model.AdminUser
res := h.DB.Where("id", id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if user.Username == SuperUsername {
resp.ERROR(c, "超级管理员不能删除")
return
}
res := h.DB.Where("id", id).Delete(&model.AdminUser{})
res = h.DB.Where("id", id).Delete(&model.AdminUser{})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -263,8 +270,14 @@ func (h *ManagerHandler) Enable(c *gin.Context) {
// ResetPass 重置密码
func (h *ManagerHandler) ResetPass(c *gin.Context) {
id := h.GetLoginUserId(c)
if id != SuperManagerID {
id := h.GetAdminId(c)
var user model.AdminUser
res := h.DB.Where("id", id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
if user.Username != SuperUsername {
resp.ERROR(c, "只有超级管理员能够进行该操作")
return
}
@@ -278,13 +291,6 @@ func (h *ManagerHandler) ResetPass(c *gin.Context) {
return
}
var user model.AdminUser
res := h.DB.Where("id", data.Id).First(&user)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)

View File

@@ -368,10 +368,7 @@ func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
return
}
if data.Enabled {
h.wxLoginService.UpdateConfig(data)
}
h.wxLoginService.UpdateConfig(data)
h.sysConfig.WxLogin = data
resp.SUCCESS(c, data)
}

View File

@@ -131,7 +131,7 @@ func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
continue // 跳过不存在的
}
tx := h.DB.Begin()
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
if job.Status != types.JMTaskStatusSuccess && job.Power > 0 {
remark := fmt.Sprintf("任务未成功退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
@@ -172,7 +172,7 @@ func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
// Stats 获取统计信息
func (h *AdminJimengHandler) Stats(c *gin.Context) {
type StatResult struct {
Status model.JMTaskStatus `json:"status"`
Status types.JMTaskStatus `json:"status"`
Count int64 `json:"count"`
}
@@ -198,13 +198,13 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
for _, stat := range stats {
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
switch stat.Status {
case model.JMTaskStatusInQueue:
case types.JMTaskStatusInQueue:
result["pendingTasks"] = stat.Count
case model.JMTaskStatusSuccess:
case types.JMTaskStatusSuccess:
result["completedTasks"] = stat.Count
case model.JMTaskStatusGenerating:
case types.JMTaskStatusGenerating:
result["processingTasks"] = stat.Count
case model.JMTaskStatusFailed:
case types.JMTaskStatusFailed:
result["failedTasks"] = stat.Count
}
}
@@ -231,29 +231,15 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
}
// 验证算力配置
if req.Power.TextToImage <= 0 {
resp.ERROR(c, "文生图算力必须大于0")
if len(req.Powers) == 0 {
resp.ERROR(c, "请至少配置一个模型的积分")
return
}
if req.Power.ImageToImage <= 0 {
resp.ERROR(c, "图生图算力必须大于0")
return
}
if req.Power.ImageEdit <= 0 {
resp.ERROR(c, "图片编辑算力必须大于0")
return
}
if req.Power.ImageEffects <= 0 {
resp.ERROR(c, "图片特效算力必须大于0")
return
}
if req.Power.TextToVideo <= 0 {
resp.ERROR(c, "文生视频算力必须大于0")
return
}
if req.Power.ImageToVideo <= 0 {
resp.ERROR(c, "图生视频算力必须大于0")
return
for key, val := range req.Powers {
if val <= 0 {
resp.ERROR(c, fmt.Sprintf("模型 %s 的积分必须大于0", key))
return
}
}
// 保存配置

View File

@@ -69,6 +69,7 @@ type ChatHandler struct {
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
userService *service.UserService
moderationManager *moderation.ServiceManager
userLocks *types.UserLockManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
@@ -80,6 +81,7 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
userService: userService,
moderationManager: moderationManager,
userLocks: types.NewUserLockManager(),
}
}
@@ -120,6 +122,14 @@ func (h *ChatHandler) Chat(c *gin.Context) {
return
}
// 用户级并发锁,确保同一用户同时只有一个对话请求
if !h.userLocks.TryLock(input.UserId) {
pushMessage(c, ChatEventError, "您有一个对话请求正在进行中,请稍后再试或先停止当前生成!")
c.Abort()
return
}
defer h.userLocks.Unlock(input.UserId)
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
@@ -262,9 +272,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
if h.App.SysConfig.Base.ContextDeep > 0 {
var historyMessages []model.ChatMessage
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
if input.LastMsgId > 0 { // 重新生成逻辑
if input.LastMsgId > 0 { // 重新生成和编辑逻辑
var lastMessage model.ChatMessage
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
err = dbSession.Where("id < ?", input.LastMsgId).Where("type", types.ReplyMsg).Order("id DESC").First(&lastMessage).Error
if err != nil {
input.LastMsgId = 0
} else {
@@ -272,7 +282,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
dbSession = dbSession.Where("id < ?", input.LastMsgId)
// 删除对应的聊天记录
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
h.DB.Debug().Where("chat_id", input.ChatId).Where("id > ?", input.LastMsgId).Delete(&model.ChatMessage{})
}
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
if err == nil {

View File

@@ -1,6 +1,7 @@
package handler
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/middleware"
@@ -38,49 +39,28 @@ func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *go
// RegisterRoutes 注册路由,新增统一任务接口
func (h *JimengHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/jimeng/")
group.GET("power-config", h.GetPowerConfig)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("task", h.CreateTask)
group.GET("power-config", h.GetPowerConfig)
group.POST("jobs", h.Jobs)
group.GET("remove", h.Remove)
group.GET("retry", h.Retry)
}
}
// JimengTaskRequest 统一任务请求结构体
// 支持所有生图和生成视频类型
type JimengTaskRequest struct {
TaskType string `json:"task_type" binding:"required"`
Prompt string `json:"prompt"`
ImageInput string `json:"image_input"`
ImageUrls []string `json:"image_urls"`
BinaryDataBase64 []string `json:"binary_data_base64"`
Scale float64 `json:"scale"`
Width int `json:"width"`
Height int `json:"height"`
Gpen float64 `json:"gpen"`
Skin float64 `json:"skin"`
SkinUnifi float64 `json:"skin_unifi"`
GenMode string `json:"gen_mode"`
Seed int64 `json:"seed"`
UsePreLLM bool `json:"use_pre_llm"`
TemplateId string `json:"template_id"`
AspectRatio string `json:"aspect_ratio"`
}
// CreateTask 统一任务创建接口
func (h *JimengHandler) CreateTask(c *gin.Context) {
var req JimengTaskRequest
var req types.JimengTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 文本审核
if h.App.SysConfig.Moderation.Enable {
if h.App.SysConfig.Moderation.Enable && req.Prompt != "" {
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
@@ -103,136 +83,21 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
}
// 新增:除图像特效外,其他任务类型必须有提示词
if req.TaskType != "image_effects" && req.Prompt == "" {
resp.ERROR(c, "提示词不能为空")
if req.Prompt == "" && len(req.ImageUrls) == 0 {
resp.ERROR(c, "提示词和图片不能同时为空")
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
if req.Seed == 0 {
req.Seed = -1
}
var powerCost int
var taskType model.JMTaskType
var params map[string]any
var reqKey string
var modelName string
switch req.TaskType {
case "text_to_image":
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToImage)
taskType = model.JMTaskTypeTextToImage
reqKey = jimeng.ReqKeyTextToImage
modelName = "即梦文生图"
if req.Scale == 0 {
req.Scale = 2.5
}
params = map[string]any{
"seed": req.Seed,
"scale": req.Scale,
"width": req.Width,
"height": req.Height,
"use_pre_llm": req.UsePreLLM,
}
case "image_to_image":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToImage)
taskType = model.JMTaskTypeImageToImage
reqKey = jimeng.ReqKeyImageToImagePortrait
modelName = "即梦图生图"
if req.Gpen == 0 {
req.Gpen = 0.4
}
if req.Skin == 0 {
req.Skin = 0.3
}
if req.GenMode == "" {
if req.Prompt != "" {
req.GenMode = jimeng.GenModeCreative
} else {
req.GenMode = jimeng.GenModeReference
}
}
params = map[string]any{
"image_input": req.ImageInput,
"width": req.Width,
"height": req.Height,
"gpen": req.Gpen,
"skin": req.Skin,
"skin_unifi": req.SkinUnifi,
"gen_mode": req.GenMode,
"seed": req.Seed,
}
case "image_edit":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEdit)
taskType = model.JMTaskTypeImageEdit
reqKey = jimeng.ReqKeyImageEdit
modelName = "即梦图像编辑"
if req.Scale == 0 {
req.Scale = 0.5
}
params = map[string]any{
"seed": req.Seed,
"scale": req.Scale,
}
params["image_urls"] = []string{req.ImageInput}
case "image_effects":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
taskType = model.JMTaskTypeImageEffects
reqKey = jimeng.ReqKeyImageEffects
modelName = "即梦图像特效"
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
params = map[string]any{
"image_input1": req.ImageInput,
"template_id": req.TemplateId,
"width": req.Width,
"height": req.Height,
}
case "text_to_video":
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
taskType = model.JMTaskTypeTextToVideo
reqKey = jimeng.ReqKeyTextToVideo
modelName = "即梦文生视频"
if req.AspectRatio == "" {
req.AspectRatio = jimeng.AspectRatio16_9
}
params = map[string]any{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
case "image_to_video":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
taskType = model.JMTaskTypeImageToVideo
reqKey = jimeng.ReqKeyImageToVideo
modelName = "即梦图生视频"
params = map[string]any{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
if len(req.ImageUrls) > 0 {
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
default:
resp.ERROR(c, "不支持的任务类型")
// 获取算力消耗
powerCost, err := h.getTaskPower(req)
if err != nil {
resp.ERROR(c, "计算任务消耗积分失败: "+err.Error())
return
}
@@ -240,16 +105,9 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
return
}
req.Power = powerCost
taskReq := &jimeng.CreateTaskRequest{
Type: taskType,
Prompt: req.Prompt,
Params: params,
ReqKey: reqKey,
Power: powerCost,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
job, err := h.jimengService.CreateTask(user.Id, &req)
if err != nil {
logger.Errorf("create jimeng task failed: %v", err)
resp.ERROR(c, "创建任务失败")
@@ -258,11 +116,42 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
Type: types.PowerConsume,
Model: "jimeng",
Remark: fmt.Sprintf("%s任务ID%d", modelName, job.Id),
Model: job.ReqKey,
Remark: h.getTaskRemark(req, job.Id),
})
resp.SUCCESS(c, job)
resp.SUCCESS(c)
}
func (h *JimengHandler) getTaskRemark(req types.JimengTaskRequest, jobId uint) string {
remark := fmt.Sprintf("即梦任务%s任务ID%d", req.ReqKey, jobId)
perUnit, ok := h.App.SysConfig.Jimeng.Powers[req.ReqKey]
if !ok || perUnit <= 0 {
return remark // Fallback if power not found or invalid
}
switch req.TaskType {
case types.JMTaskTypeImage:
remark = fmt.Sprintf("即梦图片生成任务ID%d%d积分/张", jobId, perUnit)
case types.JMTaskTypeVideo:
seconds := 0
if perUnit > 0 {
seconds = req.Power / perUnit
}
remark = fmt.Sprintf("即梦视频生成任务ID%d%d积分/秒, %d秒", jobId, perUnit, seconds)
case types.JMTaskTypeVirtualHuman:
seconds := 0
if perUnit > 0 {
seconds = req.Power / perUnit
}
remark = fmt.Sprintf("即梦数字人视频生成任务ID%d%d积分/秒, %d秒", jobId, perUnit, seconds)
case types.JMTaskTypeActionTransfer:
seconds := 0
if perUnit > 0 {
seconds = req.Power / perUnit
}
remark = fmt.Sprintf("即梦视频动作迁移任务ID%d%d积分/秒, %d秒", jobId, perUnit, seconds)
}
return remark
}
// Jobs 获取任务列表
@@ -287,17 +176,13 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
switch req.Filter {
case "image":
query = query.Where("type IN (?)", []model.JMTaskType{
model.JMTaskTypeTextToImage,
model.JMTaskTypeImageToImage,
model.JMTaskTypeImageEdit,
model.JMTaskTypeImageEffects,
})
query = query.Where("type = ?", types.JMTaskTypeImage)
case "video":
query = query.Where("type IN (?)", []model.JMTaskType{
model.JMTaskTypeTextToVideo,
model.JMTaskTypeImageToVideo,
})
query = query.Where("type = ?", types.JMTaskTypeVideo)
case "virtual_human":
query = query.Where("type = ?", types.JMTaskTypeVirtualHuman)
case "action_transfer":
query = query.Where("type = ?", types.JMTaskTypeActionTransfer)
}
if len(req.Ids) > 0 {
@@ -357,7 +242,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
}
// 正在运行中的任务不能删除
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
if job.Status == types.JMTaskStatusGenerating || job.Status == types.JMTaskStatusInQueue {
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
return
}
@@ -370,10 +255,11 @@ func (h *JimengHandler) Remove(c *gin.Context) {
}
// 失败任务删除后退回算力
if job.Status != model.JMTaskStatusFailed {
if job.Status == types.JMTaskStatusFailed {
logger.Infof("delete jimeng job failed, refund power: %d", job.Power)
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "jimeng",
Model: job.ReqKey,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
@@ -411,13 +297,13 @@ func (h *JimengHandler) Retry(c *gin.Context) {
}
// 只有失败的任务才能重试
if job.Status != model.JMTaskStatusFailed {
if job.Status != types.JMTaskStatusFailed {
resp.ERROR(c, "只有失败的任务才能重试")
return
}
// 重置任务状态
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
if err := h.jimengService.UpdateJobStatus(uint(jobId), types.JMTaskStatusInQueue, ""); err != nil {
logger.Errorf("reset job status failed: %v", err)
resp.ERROR(c, "重置任务状态失败")
return
@@ -433,25 +319,49 @@ func (h *JimengHandler) Retry(c *gin.Context) {
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
}
// getPowerFromConfig 从配置中获取指定类型的算力消耗
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
func (h *JimengHandler) getTaskPower(req types.JimengTaskRequest) (int, error) {
logger.Debugf("getTaskPower req: %+v", req)
config := h.App.SysConfig.Jimeng
switch taskType {
case model.JMTaskTypeTextToImage:
return config.Power.TextToImage
case model.JMTaskTypeImageToImage:
return config.Power.ImageToImage
case model.JMTaskTypeImageEdit:
return config.Power.ImageEdit
case model.JMTaskTypeImageEffects:
return config.Power.ImageEffects
case model.JMTaskTypeTextToVideo:
return config.Power.TextToVideo
case model.JMTaskTypeImageToVideo:
return config.Power.ImageToVideo
basePower, ok := config.Powers[req.ReqKey]
if !ok || basePower <= 0 {
return 0, errors.New("未配置模型积分或配置不合法")
}
switch req.TaskType {
case types.JMTaskTypeImage:
return basePower, nil
case types.JMTaskTypeVideo:
if req.Duration == 0 {
return 0, errors.New("视频时长不能为0")
}
return basePower * req.Duration, nil
case types.JMTaskTypeVirtualHuman:
if req.AudioURL == "" {
return 0, errors.New("音频URL不能为空")
}
audioDuration, err := utils.AudioDurationFromURL(req.AudioURL)
if err != nil {
return 0, err
}
seconds := int(audioDuration.Seconds())
if seconds <= 0 {
return 0, errors.New("音频时长无效")
}
return basePower * seconds, nil
case types.JMTaskTypeActionTransfer:
if req.VideoURL == "" {
return 0, errors.New("视频URL不能为空")
}
videoDuration, err := utils.VideoDurationMP4FromURL(req.VideoURL)
if err != nil {
return 0, err
}
seconds := int(videoDuration.Seconds())
if seconds <= 0 {
return 0, errors.New("视频时长无效")
}
return basePower * seconds, nil
default:
return 10
return 0, errors.New("任务类型不支持")
}
}
@@ -459,11 +369,6 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
config := h.App.SysConfig.Jimeng
resp.SUCCESS(c, gin.H{
"text_to_image": config.Power.TextToImage,
"image_to_image": config.Power.ImageToImage,
"image_edit": config.Power.ImageEdit,
"image_effects": config.Power.ImageEffects,
"text_to_video": config.Power.TextToVideo,
"image_to_video": config.Power.ImageToVideo,
"powers": config.Powers,
})
}

View File

@@ -1,38 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB
hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM
BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl
j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5
sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO
yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV
HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn
Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl
H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH
PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7
2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn
+jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi
BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0
aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp
dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD
TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y
aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO
ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC
5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm
VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi
t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG
wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT
yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf
LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p
j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy
1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW
-----END CERTIFICATE-----

View File

@@ -1,88 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG
EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw
MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO
UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE
MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT
V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti
W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ
MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b
53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI
pDoiVhsLwg==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj
YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj
YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0
MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV
BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j
aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF
AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk
rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2
xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp
dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6
vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl
YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1
Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H
DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98
SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG
PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe
9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC
AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90
tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy
nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf
tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq
JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3
IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW
05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41
T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI
kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop
PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N
1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y
jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02
77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi
kT9qhqn+lw==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG
EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh
dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0
WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE
CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp
YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU
WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt
rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ
4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2
zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg
wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH
Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF
BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM
E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg
MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq
MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp
bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv
b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV
nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5
4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg
wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw
WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN
z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g
KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA
uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF
MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp
emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3
U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I
UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn
DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU
1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX
Yf4Zr0fJsGuv
-----END CERTIFICATE-----

View File

@@ -1,19 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr
MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG
QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq
ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8
IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX
QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w
Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB
AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q
XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU
7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5
kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01
/9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc
LuW5DfgRItKlmUB7sw==
-----END CERTIFICATE-----

View File

@@ -1 +0,0 @@
MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo=

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.7 KiB

View File

@@ -1,15 +1,21 @@
package jimeng
import (
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"net/http"
"net/url"
"strings"
"time"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/volcengine"
)
// Client 即梦API客户端
@@ -50,6 +56,22 @@ func (c *Client) UpdateConfig(config types.JimengConfig) error {
"Version": []string{"2022-08-31"},
},
},
"CVSubmitTask": {
Method: http.MethodPost,
Path: "/",
Query: url.Values{
"Action": []string{"CVSubmitTask"},
"Version": []string{"2022-08-31"},
},
},
"CVGetResult": {
Method: http.MethodPost,
Path: "/",
Query: url.Values{
"Action": []string{"CVGetResult"},
"Version": []string{"2022-08-31"},
},
},
"CVProcess": {
Method: http.MethodPost,
Path: "/",
@@ -71,6 +93,22 @@ func (c *Client) UpdateConfig(config types.JimengConfig) error {
return c.testConnection()
}
// GetErrorMessage 根据错误代码获取对应的错误信息
func GetErrorMessage(code int) string {
if message, exists := errorCodeMessages[code]; exists {
return message
}
return fmt.Sprintf("未知错误代码: %d", code)
}
// HandleResponseError 处理响应错误,根据错误代码返回详细的错误信息
func HandleResponseError(code int, message string) error {
if code == ECSuccess {
return nil
}
return errors.New(GetErrorMessage(code))
}
// testConnection 测试即梦AI连接
func (c *Client) testConnection() error {
@@ -80,7 +118,7 @@ func (c *Client) testConnection() error {
TaskId: "test_task_id_12345",
}
_, err := c.QueryTask(testReq)
_, err := c.QueryTask(testReq, ASyncActionGetResult)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
@@ -94,7 +132,7 @@ func (c *Client) testConnection() error {
}
// SubmitTask 提交异步任务
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) {
// 直接将请求转为map[string]interface{}
reqBodyBytes, err := json.Marshal(req)
if err != nil {
@@ -103,9 +141,14 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
// 直接使用序列化后的字节
jsonBody := reqBodyBytes
action := ASyncActionSubmit
if v, ok := req["action"]; ok {
action = v.(string)
delete(req, "action")
}
// 调用SDK的JSON方法
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncSubmitTask", nil, string(jsonBody))
respBody, statusCode, err := c.visual.Client.Json(action, nil, string(jsonBody))
if err != nil {
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
}
@@ -118,11 +161,70 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
return nil, fmt.Errorf("unmarshal response failed: %w", err)
}
// 检查响应错误代码
if err := HandleResponseError(result.Code, result.Message); err != nil {
return nil, err
}
return &result, nil
}
// 识别数字人主体
func (c *Client) AvatarRecognition(imgUrl string, reqKey string) error {
params := map[string]any{
"image_url": imgUrl,
"req_key": reqKey,
}
reqBodyBytes, err := json.Marshal(params)
if err != nil {
return fmt.Errorf("marshal request failed: %w", err)
}
// 调用SDK的JSON方法
respBody, statusCode, err := c.visual.Client.Json(SyncActionSubmit, nil, string(reqBodyBytes))
if err != nil {
return fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
}
// 解析响应
var result SubmitTaskResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return fmt.Errorf("unmarshal response failed: %w", err)
}
// 检查响应错误代码
if err := HandleResponseError(result.Code, result.Message); err != nil {
return err
}
// 等待任务完成
for {
resp, err := c.QueryTask(&QueryTaskRequest{
ReqKey: reqKey,
TaskId: result.Data.TaskId,
}, SyncActionGetResult)
if err != nil {
return fmt.Errorf("query task failed: %w", err)
}
if resp.Data.Status != types.JMTaskStatusDone {
time.Sleep(time.Second * 3)
continue
}
var respData map[string]int
if err := json.Unmarshal([]byte(resp.Data.RespData), &respData); err != nil {
return fmt.Errorf("unmarshal response failed: %w", err)
}
logger.Debugf("Jimeng AvatarRecognition Response: %+v", resp)
if respData["status"] == 1 {
return nil
} else {
return errors.New("不包含人、类人、拟人等主体")
}
}
}
// QueryTask 查询任务结果
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
func (c *Client) QueryTask(req *QueryTaskRequest, action string) (*QueryTaskResponse, error) {
// 序列化请求
jsonBody, err := json.Marshal(req)
if err != nil {
@@ -130,7 +232,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
}
// 调用SDK的JSON方法
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncGetResult", nil, string(jsonBody))
respBody, statusCode, err := c.visual.Client.Json(action, nil, string(jsonBody))
if err != nil {
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
}
@@ -143,30 +245,37 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
return nil, fmt.Errorf("unmarshal response failed: %w", err)
}
return &result, nil
}
// SubmitSyncTask 提交同步任务(仅用于文生图)
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
// 序列化请求
jsonBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request failed: %w", err)
}
// 调用SDK的JSON方法
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
if err != nil {
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
}
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
// 解析响应,同步任务直接返回结果
var result QueryTaskResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w", err)
// 检查响应错误代码
if err := HandleResponseError(result.Code, result.Message); err != nil {
return nil, err
}
return &result, nil
}
// SubmitSyncImageTask 提交同步生图任务
func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) {
// 配置火山引擎访问密钥目前只支持API Key验证
client := arkruntime.NewClientWithApiKey(c.config.ApiKey)
// 构造生图请求
sequentialImageGeneration := model.SequentialImageGeneration("disabled")
generateReq := model.GenerateImagesRequest{
Model: req.ReqKey, // 模型名称
Prompt: req.Prompt, // 提示词
Size: volcengine.String(req.Size), // 图片尺寸
SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成
ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL
Watermark: volcengine.Bool(false), // 不添加水印
OptimizePrompt: volcengine.Bool(true), // 优化提示词
}
if len(req.ImageUrls) > 0 {
generateReq.Image = req.ImageUrls
}
// 调用生图 API
resp, err := client.GenerateImages(context.Background(), generateReq)
if err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -4,11 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/store"
@@ -95,35 +96,29 @@ func (s *Service) processNextTask() {
if err := s.ProcessTask(jobId); err != nil {
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, err.Error())
} else {
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
}
}
// CreateTask 创建任务
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) {
// 生成任务ID
taskId := utils.RandString(20)
// 序列化任务参数
paramsJson, err := json.Marshal(req.Params)
if err != nil {
return nil, fmt.Errorf("marshal task params failed: %w", err)
}
// 创建任务记录
job := &model.JimengJob{
UserId: userId,
TaskId: taskId,
Type: req.Type,
ReqKey: req.ReqKey,
Prompt: req.Prompt,
TaskParams: string(paramsJson),
Status: model.JMTaskStatusInQueue,
Power: req.Power,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
UserId: userId,
TaskId: taskId,
Type: req.TaskType,
ReqKey: req.ReqKey,
Prompt: req.Prompt,
Params: utils.JsonEncode(req),
Status: types.JMTaskStatusInQueue,
Power: req.Power,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 保存到数据库
@@ -148,25 +143,71 @@ func (s *Service) ProcessTask(jobId uint) error {
}
// 更新任务状态为处理中
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
if err := s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, ""); err != nil {
return fmt.Errorf("update job status failed: %w", err)
}
// 解析任务参数
var req types.JimengTaskRequest
err := utils.JsonDecode(job.Params, &req)
if err != nil {
return fmt.Errorf("parse task params failed: %w", err)
}
// 构建请求并提交任务
req, err := s.buildTaskRequest(&job)
params, err := s.buildTaskRequest(&req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
}
logger.Infof("提交即梦任务: %+v", req)
// 数字人任务,先识别主体
if req.TaskType == types.JMTaskTypeVirtualHuman {
if err := s.client.AvatarRecognition(req.ImageUrls[0], req.RecognizeKey); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("avatar recognition failed: %v", err))
}
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
// 同步任务 ,后台执行
if req.ReqKey == DoubaoSeedream40ReqKey {
go func() {
resp, err := s.client.SubmitSyncImageTask(req)
if err != nil {
_ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
return
}
logger.Infof("同步任务提交成功: %+v", resp)
// 更新原始数据
rawData, _ := json.Marshal(resp)
updates := map[string]any{
"raw_data": string(rawData),
}
if resp.Error != nil {
updates["status"] = types.JMTaskStatusFailed
updates["err_msg"] = resp.Error.Message
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
return
}
// 更新任务状态
updates["status"] = types.JMTaskStatusSuccess
// 下载图片
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false)
if err == nil {
updates["img_url"] = imgUrl
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
}()
return nil
}
logger.Debugf("提交即梦任务: %+v", params)
// 异步任务 ,前台执行
resp, err := s.client.SubmitTask(params)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
if resp.Code != CodeSuccess {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
@@ -184,172 +225,51 @@ func (s *Service) ProcessTask(jobId uint) error {
}
// buildTaskRequest 构建任务请求(统一的参数解析)
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
// 解析任务参数
func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) {
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
err := utils.JsonDecode(utils.JsonEncode(req), &params)
if err != nil {
return nil, fmt.Errorf("parse task params failed: %w", err)
}
// 构建基础请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 根据任务类型设置特定参数
switch job.Type {
case model.JMTaskTypeTextToImage:
s.setTextToImageParams(req, params)
case model.JMTaskTypeImageToImage:
s.setImageToImageParams(req, params)
case model.JMTaskTypeImageEdit:
s.setImageEditParams(req, params)
case model.JMTaskTypeImageEffects:
s.setImageEffectsParams(req, params)
case model.JMTaskTypeTextToVideo:
s.setTextToVideoParams(req, params)
case model.JMTaskTypeImageToVideo:
s.setImageToVideoParams(req, params)
default:
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
}
return req, nil
}
// setTextToImageParams 设置文生图参数
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
if usePreLlm, ok := params["use_pre_llm"]; ok {
if usePreLlmVal, ok := usePreLlm.(bool); ok {
req.UsePreLLM = usePreLlmVal
}
}
}
// setImageToImageParams 设置图生图参数
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
if imageInput, ok := params["image_input"].(string); ok {
req.ImageInput = imageInput
}
if gpen, ok := params["gpen"]; ok {
if gpenVal, ok := gpen.(float64); ok {
req.Gpen = gpenVal
}
}
if skin, ok := params["skin"]; ok {
if skinVal, ok := skin.(float64); ok {
req.Skin = skinVal
}
}
if skinUnifi, ok := params["skin_unifi"]; ok {
if skinUnifiVal, ok := skinUnifi.(float64); ok {
req.SkinUnifi = skinUnifiVal
}
}
if genMode, ok := params["gen_mode"].(string); ok {
req.GenMode = genMode
}
s.setCommonParams(req, params) // 复用通用参数
}
// setImageEditParams 设置图像编辑参数
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
if imageUrls, ok := params["image_urls"].([]any); ok {
for _, url := range imageUrls {
if urlStr, ok := url.(string); ok {
req.ImageUrls = append(req.ImageUrls, urlStr)
// 把 size 转成 width 和 height
if size, ok := params["size"]; ok {
if sizeStr, ok := size.(string); ok {
if strings.Contains(sizeStr, "x") {
sizes := strings.Split(sizeStr, "x")
params["width"] = sizes[0]
params["height"] = sizes[1]
}
}
delete(params, "size")
}
if binaryData, ok := params["binary_data_base64"].([]any); ok {
for _, data := range binaryData {
if dataStr, ok := data.(string); ok {
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
}
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
s.setCommonParams(req, params)
}
// setImageEffectsParams 设置图像特效参数
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
if imageInput1, ok := params["image_input1"].(string); ok {
req.ImageInput1 = imageInput1
}
if templateId, ok := params["template_id"].(string); ok {
req.TemplateId = templateId
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
// duration 转成 frames
if duration, ok := params["duration"]; ok {
if secs, ok := duration.(int); ok {
params["frames"] = secs*24 + 1
}
delete(params, "duration")
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
}
// setTextToVideoParams 设置文生视频参数
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
// 单独处理图片特效任务
if req.ReqKey == ImageEffectReqKey {
params["image_input1"] = req.ImageUrls[0]
delete(params, "image_urls")
}
s.setCommonParams(req, params)
}
// setImageToVideoParams 设置图生视频参数
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
// 动作迁移,数字人任务参数处理
if req.TaskType == types.JMTaskTypeVirtualHuman || req.TaskType == types.JMTaskTypeActionTransfer {
params["image_url"] = req.ImageUrls[0]
delete(params, "image_urls")
}
if req.RecognizeKey != "" {
delete(params, "recognize_key")
}
}
// setCommonParams 设置通用参数seed, width, height等
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
// 删除多余参数,剩下的就是各个任务自己专有参数了
delete(params, "type")
delete(params, "power")
return params, nil
}
// pollTaskStatus 轮询任务状态
@@ -357,7 +277,7 @@ func (s *Service) pollTaskStatus() {
for {
var jobs []model.JimengJob
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
s.db.Where("status IN (?)", []types.JMTaskStatus{types.JMTaskStatusGenerating, types.JMTaskStatusInQueue}).Find(&jobs)
if len(jobs) == 0 {
logger.Debugf("no jimeng task to poll, sleep 10s")
time.Sleep(10 * time.Second)
@@ -371,12 +291,17 @@ func (s *Service) pollTaskStatus() {
continue
}
// 豆包生图 4.0 是同步任务,不需要轮询
if job.ReqKey == DoubaoSeedream40ReqKey {
continue
}
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: job.ReqKey,
TaskId: job.TaskId,
ReqJson: `{"return_url":true}`,
})
}, ASyncActionGetResult)
if err != nil {
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
@@ -387,13 +312,13 @@ func (s *Service) pollTaskStatus() {
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
if resp.Code != 10000 {
if resp.Code != CodeSuccess {
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
continue
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
case types.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
@@ -402,7 +327,7 @@ func (s *Service) pollTaskStatus() {
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"status": types.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
@@ -425,15 +350,15 @@ func (s *Service) pollTaskStatus() {
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
case types.JMTaskStatusInQueue, types.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
case types.JMTaskStatusNotFound:
// 任务未找到
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
case types.JMTaskStatusExpired:
continue
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
@@ -448,7 +373,7 @@ func (s *Service) pollTaskStatus() {
}
// UpdateJobStatus 更新任务状态
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
func (s *Service) UpdateJobStatus(jobId uint, status types.JMTaskStatus, errMsg string) error {
updates := map[string]any{
"status": status,
"updated_at": time.Now(),
@@ -462,7 +387,7 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg
// handleTaskError 处理任务错误
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
return s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, errMsg)
}
// PushTaskToQueue 推送任务到队列(用于手动重试)
@@ -473,8 +398,8 @@ func (s *Service) PushTaskToQueue(jobId uint) error {
// GetTaskStats 获取任务统计信息
func (s *Service) GetTaskStats() (map[string]any, error) {
type StatResult struct {
Status string `json:"status"`
Count int64 `json:"count"`
Status types.JMTaskStatus `json:"status"`
Count int64 `json:"count"`
}
var stats []StatResult
@@ -496,7 +421,7 @@ func (s *Service) GetTaskStats() (map[string]any, error) {
for _, stat := range stats {
result["total"] = result["total"].(int64) + stat.Count
result[stat.Status] = stat.Count
result[string(stat.Status)] = stat.Count
}
return result, nil

View File

@@ -1,43 +1,9 @@
package jimeng
import "geekai/store/model"
// ReqKey 常量定义
const (
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
import (
"geekai/core/types"
)
// SubmitTaskRequest 提交任务请求
type SubmitTaskRequest struct {
ReqKey string `json:"req_key"`
// 文生图参数
Prompt string `json:"prompt,omitempty"`
Seed int64 `json:"seed,omitempty"`
Scale float64 `json:"scale,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
UsePreLLM bool `json:"use_pre_llm,omitempty"`
// 图生图参数
ImageInput string `json:"image_input,omitempty"`
ImageUrls []string `json:"image_urls,omitempty"`
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
Gpen float64 `json:"gpen,omitempty"`
Skin float64 `json:"skin,omitempty"`
SkinUnifi float64 `json:"skin_unifi,omitempty"`
GenMode string `json:"gen_mode,omitempty"`
// 图像编辑参数
// 图像特效参数
ImageInput1 string `json:"image_input1,omitempty"`
TemplateId string `json:"template_id,omitempty"`
// 视频生成参数
AspectRatio string `json:"aspect_ratio,omitempty"`
}
// SubmitTaskResponse 提交任务响应
type SubmitTaskResponse struct {
Code int `json:"code"`
@@ -73,7 +39,7 @@ type QueryTaskResponse struct {
ImageUrls []string `json:"image_urls"`
VideoUrl string `json:"video_url"`
RespData string `json:"resp_data"`
Status model.JMTaskStatus `json:"status"`
Status types.JMTaskStatus `json:"status"`
LlmResult string `json:"llm_result"`
PeResult string `json:"pe_result"`
PredictTagsResult string `json:"predict_tags_result"`
@@ -83,9 +49,73 @@ type QueryTaskResponse struct {
} `json:"data"`
}
const CodeSuccess = 10000
// 即梦AI错误代码常量
const (
// 成功
ECSuccess = 10000
// 请求参数错误 (50200-50215)
ECReqInvalidArgs = 50200 // 参数错误
ECReqMissingArgs = 50201 // 缺少参数
ECParseArgs = 50204 // 参数类型错误/参数缺失
ECImageSizeLimited = 50205 // 图像尺寸超过限制
ECImageEmpty = 50206 // 请求参数中没有获取到图像
ECImageDecodeError = 50207 // 图像解码错误
ECVideoEmpty = 50209 // 请求参数中没有获取到视频
ECVideoDecodeError = 50210 // 视频解码错误
ECVideoSizeLimited = 50211 // 视频尺寸超过限制
ECReqBodySizeLimited = 50213 // 请求Body过大
ECVideoTimeTooLong = 50214 // 输入视频时长过大
ECRPCProcess = 50215 // 请求处理失败
// 算法服务错误 (60102-60208)
ECJPFaceDetect = 60102 // 算法服务需要输入人脸图,但未检测到
ECFSLeaderRiskError = 60208 // 输入图片中包含敏感信息,未通过审核
// 权限和系统错误 (50400-50501)
ECAuth = 50400 // 权限校验失败
ECReqMethod = 50402 // 访问的接口不存在
ECReqLimit = 50429 // 超过调用QPS限制
ECInternal = 50500 // 服务器内部错误
ECRPCInternal = 50501 // 服务器内部RPC错误
)
// 错误代码到错误信息的映射
var errorCodeMessages = map[int]string{
// 成功
ECSuccess: "请求成功",
// 请求参数错误
ECReqInvalidArgs: "参数错误检查入参及MIME类型",
ECReqMissingArgs: "缺少参数检查入参及MIME类型",
ECParseArgs: "参数类型错误/参数缺失检查入参及MIME类型",
ECImageSizeLimited: "图像尺寸超过限制,参考接口文档入参要求部分",
ECImageEmpty: "请求参数中没有获取到图像,检查入参",
ECImageDecodeError: "图像解码错误没有获取到图像或者通过image_base64参数传递图像是base64解码错误检查输出图片或检查base64是否错误携带前缀",
ECVideoEmpty: "请求参数中没有获取到视频。输入为视频时可能返回此错误,检查入参",
ECVideoDecodeError: "视频解码错误。输入为视频时可能返回此错误,检查输入视频是否不正确",
ECVideoSizeLimited: "视频尺寸超过限制。输入为视频时可能返回此错误,检查输入视频大小",
ECReqBodySizeLimited: "请求Body过大超出接口限制检查请求Body大小",
ECVideoTimeTooLong: "输入视频时长过大,检查输入视频时长",
ECRPCProcess: "由于输入的图片、视频、参数等不满足要求,导致请求处理失败。若接口文档中有具体说明,优先参考其具体含义,按照具体服务说明进行检查",
// 算法服务错误
ECJPFaceDetect: "算法服务需要输入人脸图,但未检测到,检查输入图片是否包含人脸",
ECFSLeaderRiskError: "输入图片中包含敏感信息,未通过审核",
// 权限和系统错误
ECAuth: "权限校验失败,请检查是否已创建应用并开通服务或签名,参考接入指南及快速接入",
ECReqMethod: "访问的接口不存在,检查入参",
ECReqLimit: "超过调用QPS限制购买QPS增项包",
ECInternal: "服务器内部错误,提工单",
ECRPCInternal: "服务器内部RPC错误提工单",
}
// CreateTaskRequest 创建任务请求
type CreateTaskRequest struct {
Type model.JMTaskType `json:"type"`
Type types.JMTaskType `json:"type"`
Prompt string `json:"prompt"`
Params map[string]any `json:"params"`
ReqKey string `json:"req_key"`
@@ -93,53 +123,14 @@ type CreateTaskRequest struct {
Power int `json:"power,omitempty"`
}
// LogoInfo 水印信息
type LogoInfo struct {
AddLogo bool `json:"add_logo"`
Position int `json:"position"`
Language int `json:"language"`
Opacity float64 `json:"opacity"`
LogoTextContent string `json:"logo_text_content"`
}
// ReqJsonConfig 查询配置
type ReqJsonConfig struct {
ReturnUrl bool `json:"return_url"`
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
}
// ImageEffectTemplate 图像特效模板
const (
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
TemplateIdMyWorld = "my_world" // 像素世界风
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
TemplateIdGlassBall = "glass_ball" // 玻璃球
ImageEffectReqKey = "i2i_multi_style_zx2x"
DoubaoSeedream40ReqKey = "doubao-seedream-4-0-250828"
)
// AspectRatio 视频宽高比
const (
AspectRatio16_9 = "16:9" // 1280×720
AspectRatio9_16 = "9:16" // 720×1280
AspectRatio1_1 = "1:1" // 960×960
AspectRatio4_3 = "4:3" // 960×720
AspectRatio3_4 = "3:4" // 720×960
AspectRatio21_9 = "21:9" // 1680×720
AspectRatio9_21 = "9:21" // 720×1680
)
// GenMode 生成模式
const (
GenModeCreative = "creative" // 提示词模式
GenModeReference = "reference" // 全参考模式
GenModeReferenceChar = "reference_char" // 人物参考模式
ASyncActionSubmit = "CVSync2AsyncSubmitTask" // 异步提交任务
SyncActionSubmit = "CVSubmitTask" // 同步提交任务
ASyncActionGetResult = "CVSync2AsyncGetResult" // 异步获取结果
SyncActionGetResult = "CVGetResult" // 同步获取结果
)

View File

@@ -159,8 +159,16 @@ func (s *MigrationService) MigrateConfigContent() error {
// 数据表迁移
func (s *MigrationService) TableMigration() {
// v4.2.7 数据表迁移
if s.db.Migrator().HasColumn(&model.JimengJob{}, "task_params") {
s.db.Migrator().RenameColumn(&model.JimengJob{}, "task_params", "params")
}
// 新数据表
s.db.AutoMigrate(&model.Moderation{})
if !s.db.Migrator().HasTable(&model.Moderation{}) {
s.db.AutoMigrate(&model.Moderation{})
}
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {

View File

@@ -57,13 +57,19 @@ func (s *UserService) DecreasePower(userId uint, power int, log model.PowerLog)
defer s.lock.Unlock()
tx := s.db.Begin()
var user model.User
tx.Where("id", userId).First(&user)
if user.Power < power {
tx.Rollback()
return fmt.Errorf("用户算力不足")
}
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("扣减算力失败:%v", err)
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,

View File

@@ -1,54 +1,30 @@
package model
import (
"geekai/core/types"
"time"
)
// JimengJob 即梦AI任务模型
type JimengJob struct {
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int(11);not null;index;comment:用户ID" json:"user_id"`
TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"`
Type JMTaskType `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"`
ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"`
Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"`
TaskParams string `gorm:"column:task_params;type:text;comment:任务参数JSON" json:"task_params"`
ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"`
VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"`
RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"`
Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"`
Status JMTaskStatus `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Power int `gorm:"column:power;type:int(11);default:0;comment:消耗算力" json:"power"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int(11);not null;index;comment:用户ID" json:"user_id"`
TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"`
Type types.JMTaskType `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"`
ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"`
Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"`
Params string `gorm:"column:params;type:text;comment:任务参数JSON" json:"params"`
ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"`
VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"`
RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"`
Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"`
Status types.JMTaskStatus `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Power int `gorm:"column:power;type:int(11);default:0;comment:消耗算力" json:"power"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
}
// JMTaskStatus 任务状态
type JMTaskStatus string
const (
JMTaskStatusInQueue = JMTaskStatus("in_queue") // 任务已提交
JMTaskStatusGenerating = JMTaskStatus("generating") // 任务处理中
JMTaskStatusDone = JMTaskStatus("done") // 处理完成
JMTaskStatusNotFound = JMTaskStatus("not_found") // 任务未找到
JMTaskStatusSuccess = JMTaskStatus("success") // 任务成功
JMTaskStatusFailed = JMTaskStatus("failed") // 任务失败
JMTaskStatusExpired = JMTaskStatus("expired") // 任务过期
)
// JMTaskType 任务类型
type JMTaskType string
const (
JMTaskTypeTextToImage = JMTaskType("text_to_image") // 文生图
JMTaskTypeImageToImage = JMTaskType("image_to_image") // 图生图
JMTaskTypeImageEdit = JMTaskType("image_edit") // 图像编辑
JMTaskTypeImageEffects = JMTaskType("image_effects") // 图像特效
JMTaskTypeTextToVideo = JMTaskType("text_to_video") // 文生视频
JMTaskTypeImageToVideo = JMTaskType("image_to_video") // 图生视频
)
// TableName 返回数据表名称
func (JimengJob) TableName() string {
return "geekai_jimeng_jobs"

View File

@@ -1,23 +1,23 @@
package vo
import "geekai/store/model"
import "geekai/core/types"
// JimengJob 即梦AI任务VO
type JimengJob struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
TaskId string `json:"task_id"`
Type model.JMTaskType `json:"type"`
ReqKey string `json:"req_key"`
Prompt string `json:"prompt"`
TaskParams string `json:"task_params"`
ImgURL string `json:"img_url"`
VideoURL string `json:"video_url"`
RawData string `json:"raw_data"`
Progress int `json:"progress"`
Status model.JMTaskStatus `json:"status"`
ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"` // 时间戳
UpdatedAt int64 `json:"updated_at"` // 时间戳
Id uint `json:"id"`
UserId uint `json:"user_id"`
TaskId string `json:"task_id"`
Type types.JMTaskType `json:"type"`
ReqKey string `json:"req_key"`
Prompt string `json:"prompt"`
Params map[string]any `json:"params"`
ImgURL string `json:"img_url"`
VideoURL string `json:"video_url"`
RawData string `json:"raw_data"`
Progress int `json:"progress"`
Status types.JMTaskStatus `json:"status"`
ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"` // 时间戳
UpdatedAt int64 `json:"updated_at"` // 时间戳
}

10
api/test/test_test.go Normal file
View File

@@ -0,0 +1,10 @@
package test
import (
"fmt"
"testing"
)
func Test(t *testing.T) {
fmt.Println("test")
}

817
api/utils/media_duration.go Normal file
View File

@@ -0,0 +1,817 @@
package utils
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net/http"
"os"
"time"
)
// AudioDuration returns duration of an audio file.
// Supported formats: MP3, WAV (auto-detected by header)
func AudioDuration(path string) (time.Duration, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
// Peek first 12 bytes to detect format
head := make([]byte, 12)
n, err := io.ReadFull(f, head)
if err != nil {
return 0, err
}
if n < 12 {
return 0, errors.New("file too small")
}
// WAV: RIFF....WAVE
if string(head[0:4]) == "RIFF" && string(head[8:12]) == "WAVE" {
if _, err := f.Seek(0, io.SeekStart); err != nil {
return 0, err
}
return wavDuration(f)
}
// MP3 can start with ID3 or frame sync 0xFFEx
if string(head[0:3]) == "ID3" || (head[0] == 0xFF && (head[1]&0xE0) == 0xE0) {
if _, err := f.Seek(0, io.SeekStart); err != nil {
return 0, err
}
return mp3Duration(f)
}
return 0, errors.New("unsupported audio format")
}
// AudioDurationFromURL downloads the url to a temp file and returns duration.
func AudioDurationFromURL(url string) (time.Duration, error) {
path, err := fetchURLToTemp(url, 30*time.Second)
if err != nil {
return 0, err
}
defer os.Remove(path)
return AudioDuration(path)
}
// VideoDurationMP4 returns duration of an MP4 file (MOV/MP4 base media).
func VideoDurationMP4(path string) (time.Duration, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
return mp4Duration(f)
}
// VideoDurationMP4FromURL downloads the url to a temp file and returns duration.
func VideoDurationMP4FromURL(url string) (time.Duration, error) {
path, err := fetchURLToTemp(url, 30*time.Second)
if err != nil {
return 0, err
}
defer os.Remove(path)
return VideoDurationMP4(path)
}
// ---------------------- helpers ----------------------
func fetchURLToTemp(url string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("http status: %d", resp.StatusCode)
}
tmp, err := os.CreateTemp("", "media-*")
if err != nil {
return "", err
}
defer tmp.Close()
if _, err := io.Copy(tmp, resp.Body); err != nil {
path := tmp.Name()
_ = os.Remove(path)
return "", err
}
return tmp.Name(), nil
}
// ---------------------- WAV ----------------------
func wavDuration(r io.ReadSeeker) (time.Duration, error) {
// RIFF header already checked outside if needed. We parse chunks to get fmt and data.
// WAV little-endian
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, err
}
// Read RIFF header (12 bytes)
head := make([]byte, 12)
if _, err := io.ReadFull(r, head); err != nil {
return 0, err
}
if string(head[0:4]) != "RIFF" || string(head[8:12]) != "WAVE" {
return 0, errors.New("invalid wav header")
}
var sampleRate uint32
var numChans uint16
var bitsPerSample uint16
var byteRate uint32
var dataSize uint32
for {
chunkHdr := make([]byte, 8)
if _, err := io.ReadFull(r, chunkHdr); err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
return 0, err
}
ckID := string(chunkHdr[0:4])
ckSize := binary.LittleEndian.Uint32(chunkHdr[4:8])
switch ckID {
case "fmt ":
fmtData := make([]byte, ckSize)
if _, err := io.ReadFull(r, fmtData); err != nil {
return 0, err
}
// audioFormat := binary.LittleEndian.Uint16(fmtData[0:2]) // 1 = PCM
numChans = binary.LittleEndian.Uint16(fmtData[2:4])
sampleRate = binary.LittleEndian.Uint32(fmtData[4:8])
byteRate = binary.LittleEndian.Uint32(fmtData[8:12])
// blockAlign := binary.LittleEndian.Uint16(fmtData[12:14])
if len(fmtData) >= 16 {
bitsPerSample = binary.LittleEndian.Uint16(fmtData[14:16])
}
case "data":
dataSize = ckSize
// Skip data content
if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil {
return 0, err
}
default:
// Skip other chunks
if _, err := r.Seek(int64(ckSize), io.SeekCurrent); err != nil {
return 0, err
}
}
// Chunks are word-aligned (pad byte if odd size)
if ckSize%2 == 1 {
if _, err := r.Seek(1, io.SeekCurrent); err != nil { // skip pad
return 0, err
}
}
}
if sampleRate == 0 || numChans == 0 {
return 0, errors.New("invalid wav fmt")
}
var durationSeconds float64
if byteRate != 0 {
durationSeconds = float64(dataSize) / float64(byteRate)
} else {
bytesPerSec := float64(sampleRate) * float64(numChans) * float64(bitsPerSample) / 8.0
if bytesPerSec == 0 {
return 0, errors.New("invalid wav parameters")
}
durationSeconds = float64(dataSize) / bytesPerSec
}
return time.Duration(durationSeconds * float64(time.Second)), nil
}
// ---------------------- MP3 ----------------------
func mp3Duration(r io.ReadSeeker) (time.Duration, error) {
// Strategy:
// 1) Skip ID3v2 header if present.
// 2) Try read first frame and detect XING/Info or VBRI to get total frames and duration.
// 3) If VBR headers not present, fall back to CBR estimation: (audioDataBytes * 8) / bitrate.
// File size
fi, err := fileSizeFromSeeker(r)
if err != nil {
return 0, err
}
// Skip ID3v2
var id3v2Size int64
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, err
}
id3v2Size, err = skipID3v2(r)
if err != nil {
return 0, err
}
// Remember audio start offset
startOffset, _ := r.Seek(0, io.SeekCurrent)
// Read first frame header (search sync)
off, fh, err := findNextMP3Frame(r)
if err != nil {
return 0, err
}
if _, err := r.Seek(off, io.SeekStart); err != nil {
return 0, err
}
// Check for XING/Info header in first frame (for VBR)
totalFrames, sampleRate, samplesPerFrame, bitrateKbps, vbrFound, err := parseFirstFrameForVBR(r, fh)
if err != nil {
return 0, err
}
if vbrFound && totalFrames > 0 && sampleRate > 0 && samplesPerFrame > 0 {
seconds := (float64(totalFrames) * float64(samplesPerFrame)) / float64(sampleRate)
return time.Duration(seconds * float64(time.Second)), nil
}
// Fall back to CBR estimate using bitrate and data size (excluding ID3v2 and ID3v1)
// Detect ID3v1 at end (128 bytes TAG)
var id3v1Size int64
if fi >= 128 {
if _, err := r.Seek(fi-128, io.SeekStart); err == nil {
buf := make([]byte, 3)
if _, err := io.ReadFull(r, buf); err == nil {
if string(buf) == "TAG" {
id3v1Size = 128
}
}
}
}
audioBytes := fi - id3v2Size - id3v1Size - startOffset
if audioBytes <= 0 || bitrateKbps == 0 {
return 0, errors.New("unable to estimate mp3 duration")
}
// bitrateKbps in kbps, bytes -> bits
seconds := float64(audioBytes*8) / float64(bitrateKbps*1000)
return time.Duration(seconds * float64(time.Second)), nil
}
type mp3FrameHeader struct {
Version int // 1: MPEG1, 2: MPEG2, 25: MPEG2.5
Layer int // 1,2,3
BitrateKbps int
SampleRate int
Padding int
ChannelMode int // 0:Stereo,1:Joint,2:Dual,3:Mono
}
func findNextMP3Frame(r io.ReadSeeker) (int64, mp3FrameHeader, error) {
var hdr mp3FrameHeader
// Start from current pos and scan up to 64KB
start, _ := r.Seek(0, io.SeekCurrent)
limit := int64(64 * 1024)
buf := make([]byte, limit)
n, err := r.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return 0, hdr, err
}
for i := 0; i+4 <= n; i++ {
if buf[i] == 0xFF && (buf[i+1]&0xE0) == 0xE0 { // sync
if h, ok := parseMP3Header(buf[i : i+4]); ok {
offset := start + int64(i)
return offset, h, nil
}
}
}
return 0, hdr, errors.New("mp3 frame not found")
}
func parseMP3Header(b []byte) (mp3FrameHeader, bool) {
var h mp3FrameHeader
if len(b) < 4 {
return h, false
}
if b[0] != 0xFF || (b[1]&0xE0) != 0xE0 {
return h, false
}
versionBits := (b[1] >> 3) & 0x03
layerBits := (b[1] >> 1) & 0x03
bitrateBits := (b[2] >> 4) & 0x0F
sampleRateBits := (b[2] >> 2) & 0x03
paddingBit := (b[2] >> 1) & 0x01
channelMode := (b[3] >> 6) & 0x03
var version int
switch versionBits {
case 0x00:
version = 25 // MPEG 2.5
case 0x02:
version = 2 // MPEG 2
case 0x03:
version = 1 // MPEG 1
default:
return h, false
}
var layer int
switch layerBits {
case 0x01:
layer = 3
case 0x02:
layer = 2
case 0x03:
layer = 1
default:
return h, false
}
br := mp3BitrateKbps(version, layer, int(bitrateBits))
if br == 0 {
return h, false
}
sr := mp3SampleRate(version, int(sampleRateBits))
if sr == 0 {
return h, false
}
h = mp3FrameHeader{
Version: version,
Layer: layer,
BitrateKbps: br,
SampleRate: sr,
Padding: int(paddingBit),
ChannelMode: int(channelMode),
}
return h, true
}
func mp3BitrateKbps(version, layer, index int) int {
// index: 1..14 valid; 0,15 invalid
if index <= 0 || index == 15 {
return 0
}
// Tables per ISO/IEC 11172-3/13818-3 (common subset)
var tbl [15]int
if layer == 1 { // Layer I
if version == 1 { // MPEG1
tbl = [15]int{0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448}
} else { // MPEG2/2.5
tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256}
}
} else if layer == 2 { // Layer II
if version == 1 {
tbl = [15]int{0, 32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384}
} else {
tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160}
}
} else { // Layer III
if version == 1 {
tbl = [15]int{0, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320}
} else {
tbl = [15]int{0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160}
}
}
return tbl[index]
}
func mp3SampleRate(version, index int) int {
if index == 3 {
return 0
}
// base table for MPEG1
base := [3]int{44100, 48000, 32000}
sr := base[index]
if version == 2 { // MPEG2
sr /= 2
} else if version == 25 { // MPEG2.5
sr /= 4
}
return sr
}
func samplesPerMP3Frame(version, layer int) int {
switch layer {
case 1:
return 384
case 2:
return 1152
case 3:
if version == 1 {
return 1152
}
return 576 // MPEG2/2.5 Layer III
default:
return 0
}
}
func parseFirstFrameForVBR(r io.ReadSeeker, fh mp3FrameHeader) (totalFrames uint32, sampleRate int, samplesPerFrame int, bitrateKbps int, vbrFound bool, err error) {
// After the 4-byte header, possible side info and then XING/Info
if _, err = r.Seek(0, io.SeekCurrent); err != nil {
return
}
// Re-read header
hdr := make([]byte, 4)
if _, err = io.ReadFull(r, hdr); err != nil {
return
}
// side info size depends on MPEG version and channel mode (for Layer III)
sideInfoSize := 0
if fh.Layer == 3 { // Layer III
if fh.Version == 1 { // MPEG1
if fh.ChannelMode == 3 { // mono
sideInfoSize = 17
} else {
sideInfoSize = 32
}
} else { // MPEG2/2.5
if fh.ChannelMode == 3 {
sideInfoSize = 9
} else {
sideInfoSize = 17
}
}
}
// Read next up to 120 bytes to search for XING/Info or VBRI
buf := make([]byte, sideInfoSize+120)
if _, err = io.ReadFull(r, buf); err != nil {
// If short, still try within available
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return
}
}
// Search XING/Info signature
sigs := [][]byte{[]byte("Xing"), []byte("Info")}
for _, sig := range sigs {
idx := indexOf(buf, sig)
if idx >= 0 {
// flags after signature (4 bytes), then if frames flag set, 4 bytes frames
if len(buf) >= idx+4+4 {
flags := binary.BigEndian.Uint32(buf[idx+4 : idx+8])
var frames uint32
if (flags & 0x01) != 0 { // frames present
if len(buf) >= idx+8+4 {
frames = binary.BigEndian.Uint32(buf[idx+8 : idx+12])
}
}
if frames > 0 {
vbrFound = true
totalFrames = frames
sampleRate = fh.SampleRate
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
bitrateKbps = fh.BitrateKbps
return
}
}
}
}
// Check VBRI (usually at 32 bytes after header for MPEG1 Layer III)
if len(buf) >= 4 {
idx := indexOf(buf, []byte("VBRI"))
if idx >= 0 {
if len(buf) >= idx+4+2+2+4+4 {
// VBRI layout: 'VBRI'(4) + version(2) + delay(2) + quality(2?) varies; but at offset 10 comes bytes: bytes (4), frames (4)
// Some docs: offset 10: bytes, offset 14: frames (big-endian)
bytesOffset := idx + 10
framesOffset := idx + 14
if len(buf) >= framesOffset+4 {
frames := binary.BigEndian.Uint32(buf[framesOffset : framesOffset+4])
if frames > 0 {
vbrFound = true
totalFrames = frames
sampleRate = fh.SampleRate
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
bitrateKbps = fh.BitrateKbps
_ = bytesOffset // not used
return
}
}
}
}
}
// No VBR header. Provide header info for CBR fallback
sampleRate = fh.SampleRate
samplesPerFrame = samplesPerMP3Frame(fh.Version, fh.Layer)
bitrateKbps = fh.BitrateKbps
return
}
func indexOf(haystack []byte, needle []byte) int {
for i := 0; i+len(needle) <= len(haystack); i++ {
match := true
for j := 0; j < len(needle); j++ {
if haystack[i+j] != needle[j] {
match = false
break
}
}
if match {
return i
}
}
return -1
}
func skipID3v2(r io.ReadSeeker) (int64, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, err
}
head := make([]byte, 10)
if _, err := io.ReadFull(r, head); err != nil {
return 0, nil // no header
}
if string(head[0:3]) != "ID3" {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, err
}
return 0, nil
}
// size: 4 synchsafe bytes
sz := int64((int(head[6]&0x7F) << 21) | (int(head[7]&0x7F) << 14) | (int(head[8]&0x7F) << 7) | int(head[9]&0x7F))
// total header size = 10 + sz (+ footer 10 if flag set)
footer := int64(0)
if (head[5] & 0x10) != 0 { // footer present
footer = 10
}
total := 10 + sz + footer
if _, err := r.Seek(total, io.SeekStart); err != nil {
return 0, err
}
return total, nil
}
func fileSizeFromSeeker(r io.ReadSeeker) (int64, error) {
cur, err := r.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}
end, err := r.Seek(0, io.SeekEnd)
if err != nil {
return 0, err
}
if _, err := r.Seek(cur, io.SeekStart); err != nil {
return 0, err
}
return end, nil
}
// ---------------------- MP4 ----------------------
type mp4BoxHeader struct {
Size uint64
Type [4]byte
}
func readBoxHeader(r io.ReadSeeker) (mp4BoxHeader, error) {
var h mp4BoxHeader
buf := make([]byte, 8)
if _, err := io.ReadFull(r, buf); err != nil {
return h, err
}
sz := binary.BigEndian.Uint32(buf[0:4])
copy(h.Type[:], buf[4:8])
if sz == 1 {
// 64-bit size follows
ext := make([]byte, 8)
if _, err := io.ReadFull(r, ext); err != nil {
return h, err
}
h.Size = binary.BigEndian.Uint64(ext)
} else {
h.Size = uint64(sz)
}
return h, nil
}
func skipBox(r io.ReadSeeker, boxSize uint64, alreadyRead int64) error {
toSkip := int64(boxSize) - alreadyRead
if toSkip < 0 {
return fmt.Errorf("invalid box size")
}
_, err := r.Seek(toSkip, io.SeekCurrent)
return err
}
func mp4Duration(r io.ReadSeeker) (time.Duration, error) {
if _, err := r.Seek(0, io.SeekStart); err != nil {
return 0, err
}
// Find moov box
var moovStart int64
var moovSize uint64
for {
pos, _ := r.Seek(0, io.SeekCurrent)
h, err := readBoxHeader(r)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
return 0, err
}
if string(h.Type[:]) == "moov" {
moovStart = pos
moovSize = h.Size
break
}
if h.Size < 8 {
return 0, errors.New("invalid mp4 box size")
}
if err := skipBox(r, h.Size, 8); err != nil {
return 0, err
}
}
if moovStart == 0 && moovSize == 0 {
return 0, errors.New("moov not found")
}
// Parse inside moov for video trak mdhd, else mvhd
if _, err := r.Seek(moovStart+8, io.SeekStart); err != nil { // skip moov header
return 0, err
}
end := moovStart + int64(moovSize)
var movieTimescale uint32
var movieDuration uint64
var foundVideoMdhd bool
var mdhdTimescale uint32
var mdhdDuration uint64
for {
pos, _ := r.Seek(0, io.SeekCurrent)
if pos >= end {
break
}
h, err := readBoxHeader(r)
if err != nil {
return 0, err
}
switch string(h.Type[:]) {
case "mvhd":
// movie header
ver := make([]byte, 1)
if _, err := io.ReadFull(r, ver); err != nil {
return 0, err
}
if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags
return 0, err
}
if ver[0] == 1 {
// version 1: 64-bit duration
buf := make([]byte, 8+8+4+8) // ctime(8) mtime(8) timescale(4) duration(8)
if _, err := io.ReadFull(r, buf); err != nil {
return 0, err
}
movieTimescale = binary.BigEndian.Uint32(buf[16:20])
movieDuration = binary.BigEndian.Uint64(buf[20:28])
} else {
buf := make([]byte, 4+4+4+4) // ctime mtime timescale duration
if _, err := io.ReadFull(r, buf); err != nil {
return 0, err
}
movieTimescale = binary.BigEndian.Uint32(buf[8:12])
movieDuration = uint64(binary.BigEndian.Uint32(buf[12:16]))
}
// skip rest of mvhd
read := int64(1 + 3)
if ver[0] == 1 {
read += int64(8 + 8 + 4 + 8)
} else {
read += int64(4 + 4 + 4 + 4)
}
if err := skipBox(r, h.Size, 8+read); err != nil {
return 0, err
}
case "trak":
// parse trak for hdlr and mdhd
tEnd := int64(0)
if h.Size < 8 {
return 0, errors.New("invalid trak size")
}
tEnd = pos + int64(h.Size)
var isVideo bool
var tMdhdTimescale uint32
var tMdhdDuration uint64
for {
cpos, _ := r.Seek(0, io.SeekCurrent)
if cpos >= tEnd {
break
}
ch, err := readBoxHeader(r)
if err != nil {
return 0, err
}
switch string(ch.Type[:]) {
case "mdia":
mEnd := cpos + int64(ch.Size)
for {
mpos, _ := r.Seek(0, io.SeekCurrent)
if mpos >= mEnd {
break
}
mh, err := readBoxHeader(r)
if err != nil {
return 0, err
}
switch string(mh.Type[:]) {
case "hdlr":
// skip version+flags (4), pre_defined(4)
if _, err := r.Seek(8, io.SeekCurrent); err != nil {
return 0, err
}
handler := make([]byte, 4)
if _, err := io.ReadFull(r, handler); err != nil {
return 0, err
}
if string(handler) == "vide" {
isVideo = true
}
if err := skipBox(r, mh.Size, 8+8+4); err != nil { // header + skipped + read handler
return 0, err
}
case "mdhd":
ver := make([]byte, 1)
if _, err := io.ReadFull(r, ver); err != nil {
return 0, err
}
if _, err := r.Seek(3, io.SeekCurrent); err != nil { // flags
return 0, err
}
if ver[0] == 1 {
buf := make([]byte, 8+8+4+8)
if _, err := io.ReadFull(r, buf); err != nil {
return 0, err
}
tMdhdTimescale = binary.BigEndian.Uint32(buf[16:20])
tMdhdDuration = binary.BigEndian.Uint64(buf[20:28])
} else {
buf := make([]byte, 4+4+4+4)
if _, err := io.ReadFull(r, buf); err != nil {
return 0, err
}
tMdhdTimescale = binary.BigEndian.Uint32(buf[8:12])
tMdhdDuration = uint64(binary.BigEndian.Uint32(buf[12:16]))
}
if err := skipBox(r, mh.Size, 8+1+3+int64(lenVersionPayload(ver[0]))); err != nil {
return 0, err
}
default:
if err := skipBox(r, mh.Size, 8); err != nil {
return 0, err
}
}
}
default:
if err := skipBox(r, ch.Size, 8); err != nil {
return 0, err
}
}
}
if isVideo && tMdhdTimescale != 0 && tMdhdDuration != 0 {
foundVideoMdhd = true
mdhdTimescale = tMdhdTimescale
mdhdDuration = tMdhdDuration
}
// Skip remaining of trak if any
if _, err := r.Seek(tEnd, io.SeekStart); err != nil {
return 0, err
}
default:
if err := skipBox(r, h.Size, 8); err != nil {
return 0, err
}
}
}
if foundVideoMdhd && mdhdTimescale != 0 {
sec := float64(mdhdDuration) / float64(mdhdTimescale)
return time.Duration(sec * float64(time.Second)), nil
}
if movieTimescale != 0 {
sec := float64(movieDuration) / float64(movieTimescale)
return time.Duration(sec * float64(time.Second)), nil
}
return 0, errors.New("failed to read mp4 duration")
}
func lenVersionPayload(ver byte) int {
if ver == 1 {
return 8 + 8 + 4 + 8
}
return 4 + 4 + 4 + 4
}