mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-10 19:54:25 +08:00
Merge tag 'v4.2.7'
This commit is contained in:
@@ -1,15 +1,21 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
@@ -50,6 +56,22 @@ func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVSubmitTask": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSubmitTask"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVGetResult": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVGetResult"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVProcess": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
@@ -71,6 +93,22 @@ func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
return c.testConnection()
|
||||
}
|
||||
|
||||
// GetErrorMessage 根据错误代码获取对应的错误信息
|
||||
func GetErrorMessage(code int) string {
|
||||
if message, exists := errorCodeMessages[code]; exists {
|
||||
return message
|
||||
}
|
||||
return fmt.Sprintf("未知错误代码: %d", code)
|
||||
}
|
||||
|
||||
// HandleResponseError 处理响应错误,根据错误代码返回详细的错误信息
|
||||
func HandleResponseError(code int, message string) error {
|
||||
if code == ECSuccess {
|
||||
return nil
|
||||
}
|
||||
return errors.New(GetErrorMessage(code))
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (c *Client) testConnection() error {
|
||||
|
||||
@@ -80,7 +118,7 @@ func (c *Client) testConnection() error {
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := c.QueryTask(testReq)
|
||||
_, err := c.QueryTask(testReq, ASyncActionGetResult)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
@@ -94,7 +132,7 @@ func (c *Client) testConnection() error {
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
@@ -103,9 +141,14 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
action := ASyncActionSubmit
|
||||
if v, ok := req["action"]; ok {
|
||||
action = v.(string)
|
||||
delete(req, "action")
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncSubmitTask", nil, string(jsonBody))
|
||||
respBody, statusCode, err := c.visual.Client.Json(action, nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
@@ -118,11 +161,70 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查响应错误代码
|
||||
if err := HandleResponseError(result.Code, result.Message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// 识别数字人主体
|
||||
func (c *Client) AvatarRecognition(imgUrl string, reqKey string) error {
|
||||
params := map[string]any{
|
||||
"image_url": imgUrl,
|
||||
"req_key": reqKey,
|
||||
}
|
||||
reqBodyBytes, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json(SyncActionSubmit, nil, string(reqBodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查响应错误代码
|
||||
if err := HandleResponseError(result.Code, result.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 等待任务完成
|
||||
for {
|
||||
resp, err := c.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: reqKey,
|
||||
TaskId: result.Data.TaskId,
|
||||
}, SyncActionGetResult)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query task failed: %w", err)
|
||||
}
|
||||
if resp.Data.Status != types.JMTaskStatusDone {
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var respData map[string]int
|
||||
if err := json.Unmarshal([]byte(resp.Data.RespData), &respData); err != nil {
|
||||
return fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
logger.Debugf("Jimeng AvatarRecognition Response: %+v", resp)
|
||||
if respData["status"] == 1 {
|
||||
return nil
|
||||
} else {
|
||||
return errors.New("不包含人、类人、拟人等主体")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// QueryTask 查询任务结果
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest, action string) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
@@ -130,7 +232,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncGetResult", nil, string(jsonBody))
|
||||
respBody, statusCode, err := c.visual.Client.Json(action, nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
@@ -143,30 +245,37 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
// 检查响应错误代码
|
||||
if err := HandleResponseError(result.Code, result.Message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncImageTask 提交同步生图任务
|
||||
func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) {
|
||||
// 配置火山引擎访问密钥,目前只支持API Key验证
|
||||
client := arkruntime.NewClientWithApiKey(c.config.ApiKey)
|
||||
// 构造生图请求
|
||||
sequentialImageGeneration := model.SequentialImageGeneration("disabled")
|
||||
generateReq := model.GenerateImagesRequest{
|
||||
Model: req.ReqKey, // 模型名称
|
||||
Prompt: req.Prompt, // 提示词
|
||||
Size: volcengine.String(req.Size), // 图片尺寸
|
||||
SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成
|
||||
ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL
|
||||
Watermark: volcengine.Bool(false), // 不添加水印
|
||||
OptimizePrompt: volcengine.Bool(true), // 优化提示词
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
generateReq.Image = req.ImageUrls
|
||||
}
|
||||
// 调用生图 API
|
||||
resp, err := client.GenerateImages(context.Background(), generateReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
@@ -95,35 +96,29 @@ func (s *Service) processNextTask() {
|
||||
|
||||
if err := s.ProcessTask(jobId); err != nil {
|
||||
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||
s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, err.Error())
|
||||
} else {
|
||||
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
TaskParams: string(paramsJson),
|
||||
Status: model.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.TaskType,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
Params: utils.JsonEncode(req),
|
||||
Status: types.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
@@ -148,25 +143,71 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
|
||||
if err := s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析任务参数
|
||||
var req types.JimengTaskRequest
|
||||
err := utils.JsonDecode(job.Params, &req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求并提交任务
|
||||
req, err := s.buildTaskRequest(&job)
|
||||
params, err := s.buildTaskRequest(&req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
// 数字人任务,先识别主体
|
||||
if req.TaskType == types.JMTaskTypeVirtualHuman {
|
||||
if err := s.client.AvatarRecognition(req.ImageUrls[0], req.RecognizeKey); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("avatar recognition failed: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
// 同步任务 ,后台执行
|
||||
if req.ReqKey == DoubaoSeedream40ReqKey {
|
||||
go func() {
|
||||
resp, err := s.client.SubmitSyncImageTask(req)
|
||||
if err != nil {
|
||||
_ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
return
|
||||
}
|
||||
logger.Infof("同步任务提交成功: %+v", resp)
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
updates := map[string]any{
|
||||
"raw_data": string(rawData),
|
||||
}
|
||||
if resp.Error != nil {
|
||||
updates["status"] = types.JMTaskStatusFailed
|
||||
updates["err_msg"] = resp.Error.Message
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
updates["status"] = types.JMTaskStatusSuccess
|
||||
// 下载图片
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false)
|
||||
if err == nil {
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debugf("提交即梦任务: %+v", params)
|
||||
// 异步任务 ,前台执行
|
||||
resp, err := s.client.SubmitTask(params)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
if resp.Code != CodeSuccess {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
@@ -184,172 +225,51 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||
// 解析任务参数
|
||||
func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) {
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
err := utils.JsonDecode(utils.JsonEncode(req), ¶ms)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建基础请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
s.setImageToImageParams(req, params)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
s.setImageEditParams(req, params)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
s.setImageEffectsParams(req, params)
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
s.setTextToVideoParams(req, params)
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
s.setImageToVideoParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// setTextToImageParams 设置文生图参数
|
||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setImageToImageParams 设置图生图参数
|
||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
s.setCommonParams(req, params) // 复用通用参数
|
||||
}
|
||||
|
||||
// setImageEditParams 设置图像编辑参数
|
||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
// 把 size 转成 width 和 height
|
||||
if size, ok := params["size"]; ok {
|
||||
if sizeStr, ok := size.(string); ok {
|
||||
if strings.Contains(sizeStr, "x") {
|
||||
sizes := strings.Split(sizeStr, "x")
|
||||
params["width"] = sizes[0]
|
||||
params["height"] = sizes[1]
|
||||
}
|
||||
}
|
||||
delete(params, "size")
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageEffectsParams 设置图像特效参数
|
||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
// duration 转成 frames
|
||||
if duration, ok := params["duration"]; ok {
|
||||
if secs, ok := duration.(int); ok {
|
||||
params["frames"] = secs*24 + 1
|
||||
}
|
||||
delete(params, "duration")
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTextToVideoParams 设置文生视频参数
|
||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
// 单独处理图片特效任务
|
||||
if req.ReqKey == ImageEffectReqKey {
|
||||
params["image_input1"] = req.ImageUrls[0]
|
||||
delete(params, "image_urls")
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageToVideoParams 设置图生视频参数
|
||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
// 动作迁移,数字人任务参数处理
|
||||
if req.TaskType == types.JMTaskTypeVirtualHuman || req.TaskType == types.JMTaskTypeActionTransfer {
|
||||
params["image_url"] = req.ImageUrls[0]
|
||||
delete(params, "image_urls")
|
||||
}
|
||||
if req.RecognizeKey != "" {
|
||||
delete(params, "recognize_key")
|
||||
}
|
||||
}
|
||||
|
||||
// setCommonParams 设置通用参数(seed, width, height等)
|
||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
// 删除多余参数,剩下的就是各个任务自己专有参数了
|
||||
delete(params, "type")
|
||||
delete(params, "power")
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
@@ -357,7 +277,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
s.db.Where("status IN (?)", []types.JMTaskStatus{types.JMTaskStatusGenerating, types.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
@@ -371,12 +291,17 @@ func (s *Service) pollTaskStatus() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 豆包生图 4.0 是同步任务,不需要轮询
|
||||
if job.ReqKey == DoubaoSeedream40ReqKey {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
TaskId: job.TaskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
}, ASyncActionGetResult)
|
||||
|
||||
if err != nil {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
|
||||
@@ -387,13 +312,13 @@ func (s *Service) pollTaskStatus() {
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
if resp.Code != CodeSuccess {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
case types.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
@@ -402,7 +327,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"status": types.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
@@ -425,15 +350,15 @@ func (s *Service) pollTaskStatus() {
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
case types.JMTaskStatusInQueue, types.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
case types.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
case types.JMTaskStatusExpired:
|
||||
continue
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
@@ -448,7 +373,7 @@ func (s *Service) pollTaskStatus() {
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status types.JMTaskStatus, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
@@ -462,7 +387,7 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
return s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
@@ -473,8 +398,8 @@ func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
Status types.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
@@ -496,7 +421,7 @@ func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
result[string(stat.Status)] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
@@ -1,43 +1,9 @@
|
||||
package jimeng
|
||||
|
||||
import "geekai/store/model"
|
||||
|
||||
// ReqKey 常量定义
|
||||
const (
|
||||
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
|
||||
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
|
||||
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
|
||||
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
|
||||
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
|
||||
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
|
||||
import (
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
@@ -73,7 +39,7 @@ type QueryTaskResponse struct {
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Status types.JMTaskStatus `json:"status"`
|
||||
LlmResult string `json:"llm_result"`
|
||||
PeResult string `json:"pe_result"`
|
||||
PredictTagsResult string `json:"predict_tags_result"`
|
||||
@@ -83,9 +49,73 @@ type QueryTaskResponse struct {
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
const CodeSuccess = 10000
|
||||
|
||||
// 即梦AI错误代码常量
|
||||
const (
|
||||
// 成功
|
||||
ECSuccess = 10000
|
||||
|
||||
// 请求参数错误 (50200-50215)
|
||||
ECReqInvalidArgs = 50200 // 参数错误
|
||||
ECReqMissingArgs = 50201 // 缺少参数
|
||||
ECParseArgs = 50204 // 参数类型错误/参数缺失
|
||||
ECImageSizeLimited = 50205 // 图像尺寸超过限制
|
||||
ECImageEmpty = 50206 // 请求参数中没有获取到图像
|
||||
ECImageDecodeError = 50207 // 图像解码错误
|
||||
ECVideoEmpty = 50209 // 请求参数中没有获取到视频
|
||||
ECVideoDecodeError = 50210 // 视频解码错误
|
||||
ECVideoSizeLimited = 50211 // 视频尺寸超过限制
|
||||
ECReqBodySizeLimited = 50213 // 请求Body过大
|
||||
ECVideoTimeTooLong = 50214 // 输入视频时长过大
|
||||
ECRPCProcess = 50215 // 请求处理失败
|
||||
|
||||
// 算法服务错误 (60102-60208)
|
||||
ECJPFaceDetect = 60102 // 算法服务需要输入人脸图,但未检测到
|
||||
ECFSLeaderRiskError = 60208 // 输入图片中包含敏感信息,未通过审核
|
||||
|
||||
// 权限和系统错误 (50400-50501)
|
||||
ECAuth = 50400 // 权限校验失败
|
||||
ECReqMethod = 50402 // 访问的接口不存在
|
||||
ECReqLimit = 50429 // 超过调用QPS限制
|
||||
ECInternal = 50500 // 服务器内部错误
|
||||
ECRPCInternal = 50501 // 服务器内部RPC错误
|
||||
)
|
||||
|
||||
// 错误代码到错误信息的映射
|
||||
var errorCodeMessages = map[int]string{
|
||||
// 成功
|
||||
ECSuccess: "请求成功",
|
||||
|
||||
// 请求参数错误
|
||||
ECReqInvalidArgs: "参数错误,检查入参及MIME类型",
|
||||
ECReqMissingArgs: "缺少参数,检查入参及MIME类型",
|
||||
ECParseArgs: "参数类型错误/参数缺失,检查入参及MIME类型",
|
||||
ECImageSizeLimited: "图像尺寸超过限制,参考接口文档入参要求部分",
|
||||
ECImageEmpty: "请求参数中没有获取到图像,检查入参",
|
||||
ECImageDecodeError: "图像解码错误:没有获取到图像或者通过image_base64参数传递图像是base64解码错误,检查输出图片或检查base64是否错误携带前缀",
|
||||
ECVideoEmpty: "请求参数中没有获取到视频。输入为视频时可能返回此错误,检查入参",
|
||||
ECVideoDecodeError: "视频解码错误。输入为视频时可能返回此错误,检查输入视频是否不正确",
|
||||
ECVideoSizeLimited: "视频尺寸超过限制。输入为视频时可能返回此错误,检查输入视频大小",
|
||||
ECReqBodySizeLimited: "请求Body过大,超出接口限制,检查请求Body大小",
|
||||
ECVideoTimeTooLong: "输入视频时长过大,检查输入视频时长",
|
||||
ECRPCProcess: "由于输入的图片、视频、参数等不满足要求,导致请求处理失败。若接口文档中有具体说明,优先参考其具体含义,按照具体服务说明进行检查",
|
||||
|
||||
// 算法服务错误
|
||||
ECJPFaceDetect: "算法服务需要输入人脸图,但未检测到,检查输入图片是否包含人脸",
|
||||
ECFSLeaderRiskError: "输入图片中包含敏感信息,未通过审核",
|
||||
|
||||
// 权限和系统错误
|
||||
ECAuth: "权限校验失败,请检查是否已创建应用并开通服务或签名,参考接入指南及快速接入",
|
||||
ECReqMethod: "访问的接口不存在,检查入参",
|
||||
ECReqLimit: "超过调用QPS限制,购买QPS增项包",
|
||||
ECInternal: "服务器内部错误,提工单",
|
||||
ECRPCInternal: "服务器内部RPC错误,提工单",
|
||||
}
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type model.JMTaskType `json:"type"`
|
||||
Type types.JMTaskType `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Params map[string]any `json:"params"`
|
||||
ReqKey string `json:"req_key"`
|
||||
@@ -93,53 +123,14 @@ type CreateTaskRequest struct {
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
// LogoInfo 水印信息
|
||||
type LogoInfo struct {
|
||||
AddLogo bool `json:"add_logo"`
|
||||
Position int `json:"position"`
|
||||
Language int `json:"language"`
|
||||
Opacity float64 `json:"opacity"`
|
||||
LogoTextContent string `json:"logo_text_content"`
|
||||
}
|
||||
|
||||
// ReqJsonConfig 查询配置
|
||||
type ReqJsonConfig struct {
|
||||
ReturnUrl bool `json:"return_url"`
|
||||
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
|
||||
}
|
||||
|
||||
// ImageEffectTemplate 图像特效模板
|
||||
const (
|
||||
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
|
||||
TemplateIdMyWorld = "my_world" // 像素世界风
|
||||
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
|
||||
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
|
||||
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
|
||||
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
|
||||
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
|
||||
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
|
||||
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
|
||||
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
|
||||
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
|
||||
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
|
||||
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
|
||||
TemplateIdGlassBall = "glass_ball" // 玻璃球
|
||||
ImageEffectReqKey = "i2i_multi_style_zx2x"
|
||||
DoubaoSeedream40ReqKey = "doubao-seedream-4-0-250828"
|
||||
)
|
||||
|
||||
// AspectRatio 视频宽高比
|
||||
const (
|
||||
AspectRatio16_9 = "16:9" // 1280×720
|
||||
AspectRatio9_16 = "9:16" // 720×1280
|
||||
AspectRatio1_1 = "1:1" // 960×960
|
||||
AspectRatio4_3 = "4:3" // 960×720
|
||||
AspectRatio3_4 = "3:4" // 720×960
|
||||
AspectRatio21_9 = "21:9" // 1680×720
|
||||
AspectRatio9_21 = "9:21" // 720×1680
|
||||
)
|
||||
|
||||
// GenMode 生成模式
|
||||
const (
|
||||
GenModeCreative = "creative" // 提示词模式
|
||||
GenModeReference = "reference" // 全参考模式
|
||||
GenModeReferenceChar = "reference_char" // 人物参考模式
|
||||
ASyncActionSubmit = "CVSync2AsyncSubmitTask" // 异步提交任务
|
||||
SyncActionSubmit = "CVSubmitTask" // 同步提交任务
|
||||
ASyncActionGetResult = "CVSync2AsyncGetResult" // 异步获取结果
|
||||
SyncActionGetResult = "CVGetResult" // 同步获取结果
|
||||
)
|
||||
|
||||
@@ -159,8 +159,16 @@ func (s *MigrationService) MigrateConfigContent() error {
|
||||
|
||||
// 数据表迁移
|
||||
func (s *MigrationService) TableMigration() {
|
||||
|
||||
// v4.2.7 数据表迁移
|
||||
if s.db.Migrator().HasColumn(&model.JimengJob{}, "task_params") {
|
||||
s.db.Migrator().RenameColumn(&model.JimengJob{}, "task_params", "params")
|
||||
}
|
||||
|
||||
// 新数据表
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
if !s.db.Migrator().HasTable(&model.Moderation{}) {
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
}
|
||||
|
||||
// 订单字段整理
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||
|
||||
@@ -57,13 +57,19 @@ func (s *UserService) DecreasePower(userId uint, power int, log model.PowerLog)
|
||||
defer s.lock.Unlock()
|
||||
|
||||
tx := s.db.Begin()
|
||||
var user model.User
|
||||
tx.Where("id", userId).First(&user)
|
||||
if user.Power < power {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("用户算力不足")
|
||||
}
|
||||
|
||||
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("扣减算力失败:%v", err)
|
||||
}
|
||||
var user model.User
|
||||
tx.Where("id", userId).First(&user)
|
||||
|
||||
err = tx.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
|
||||
Reference in New Issue
Block a user