acommpelish jimeng AI refactor for PC

This commit is contained in:
GeekMaster
2025-09-12 18:58:52 +08:00
parent c5badb3e13
commit 2c6eee7fc1
13 changed files with 1049 additions and 266 deletions

View File

@@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -103,24 +103,18 @@ func (s *Service) processNextTask() {
}
// CreateTask 创建任务
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) {
// 生成任务ID
taskId := utils.RandString(20)
// 序列化任务参数
paramsJson, err := json.Marshal(req.Params)
if err != nil {
return nil, fmt.Errorf("marshal task params failed: %w", err)
}
// 创建任务记录
job := &model.JimengJob{
UserId: userId,
TaskId: taskId,
Type: req.Type,
Type: req.TaskType,
ReqKey: req.ReqKey,
Prompt: req.Prompt,
Params: string(paramsJson),
Params: utils.JsonEncode(req),
Status: types.JMTaskStatusInQueue,
Power: req.Power,
CreatedAt: time.Now(),
@@ -153,21 +147,61 @@ func (s *Service) ProcessTask(jobId uint) error {
return fmt.Errorf("update job status failed: %w", err)
}
// 解析任务参数
var req types.JimengTaskRequest
err := utils.JsonDecode(job.Params, &req)
if err != nil {
return fmt.Errorf("parse task params failed: %w", err)
}
// 构建请求并提交任务
req, err := s.buildTaskRequest(&job)
params, err := s.buildTaskRequest(&req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
}
logger.Infof("提交即梦任务: %+v", req)
logger.Debugf("提交即梦任务: %+v", params)
// 提交异步任务
resp, err := s.client.SubmitTask(req)
// 同步任务 ,后台执行
if req.ReqKey == DoubaoSeedream40ReqKey {
go func() {
resp, err := s.client.SubmitSyncImageTask(req)
if err != nil {
_ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
return
}
logger.Infof("同步任务提交成功: %+v", resp)
// 更新原始数据
rawData, _ := json.Marshal(resp)
updates := map[string]any{
"raw_data": string(rawData),
}
if resp.Error != nil {
updates["status"] = types.JMTaskStatusFailed
updates["err_msg"] = resp.Error.Message
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
return
}
// 更新任务状态
updates["status"] = types.JMTaskStatusSuccess
// 下载图片
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false)
if err == nil {
updates["img_url"] = imgUrl
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
}()
return nil
}
// 异步任务 ,前台执行
resp, err := s.client.SubmitTask(params)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
if resp.Code != CodeSuccess {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
@@ -185,168 +219,36 @@ func (s *Service) ProcessTask(jobId uint) error {
}
// buildTaskRequest 构建任务请求(统一的参数解析)
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
// 解析任务参数
func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) {
var params map[string]any
if err := json.Unmarshal([]byte(job.Params), &params); err != nil {
err := utils.JsonDecode(utils.JsonEncode(req), &params)
if err != nil {
return nil, fmt.Errorf("parse task params failed: %w", err)
}
// 构建基础请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 根据任务类型设置特定参数
switch job.Type {
case types.JMTaskTypeImage:
s.setTextToImageParams(req, params)
case types.JMTaskTypeVideo:
s.setImageToImageParams(req, params)
case types.JMTaskTypeVirtualHuman:
s.setImageEditParams(req, params)
case types.JMTaskTypeActionTransfer:
s.setImageEffectsParams(req, params)
default:
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
}
return req, nil
}
// setTextToImageParams 设置文生图参数
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
if usePreLlm, ok := params["use_pre_llm"]; ok {
if usePreLlmVal, ok := usePreLlm.(bool); ok {
req.UsePreLLM = usePreLlmVal
}
}
}
// setImageToImageParams 设置图生图参数
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
if imageInput, ok := params["image_input"].(string); ok {
req.ImageInput = imageInput
}
if gpen, ok := params["gpen"]; ok {
if gpenVal, ok := gpen.(float64); ok {
req.Gpen = gpenVal
}
}
if skin, ok := params["skin"]; ok {
if skinVal, ok := skin.(float64); ok {
req.Skin = skinVal
}
}
if skinUnifi, ok := params["skin_unifi"]; ok {
if skinUnifiVal, ok := skinUnifi.(float64); ok {
req.SkinUnifi = skinUnifiVal
}
}
if genMode, ok := params["gen_mode"].(string); ok {
req.GenMode = genMode
}
s.setCommonParams(req, params) // 复用通用参数
}
// setImageEditParams 设置图像编辑参数
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
if imageUrls, ok := params["image_urls"].([]any); ok {
for _, url := range imageUrls {
if urlStr, ok := url.(string); ok {
req.ImageUrls = append(req.ImageUrls, urlStr)
// 把 size 转成 width 和 height
if size, ok := params["size"]; ok {
if sizeStr, ok := size.(string); ok {
if strings.Contains(sizeStr, "x") {
sizes := strings.Split(sizeStr, "x")
params["width"] = sizes[0]
params["height"] = sizes[1]
}
}
delete(params, "size")
}
if binaryData, ok := params["binary_data_base64"].([]any); ok {
for _, data := range binaryData {
if dataStr, ok := data.(string); ok {
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
}
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
s.setCommonParams(req, params)
}
// setImageEffectsParams 设置图像特效参数
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
if imageInput1, ok := params["image_input1"].(string); ok {
req.ImageInput1 = imageInput1
}
if templateId, ok := params["template_id"].(string); ok {
req.TemplateId = templateId
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
// duration 转成 frames
if duration, ok := params["duration"]; ok {
if secs, ok := duration.(int); ok {
params["frames"] = secs*24 + 1
}
delete(params, "duration")
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
}
// setTextToVideoParams 设置文生视频参数
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
}
s.setCommonParams(req, params)
}
// setImageToVideoParams 设置图生视频参数
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
}
}
// setCommonParams 设置通用参数seed, width, height等
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
// 删除多余参数,剩下的就是各个任务自己专有参数
delete(params, "type")
delete(params, "power")
return params, nil
}
// pollTaskStatus 轮询任务状态
@@ -368,6 +270,11 @@ func (s *Service) pollTaskStatus() {
continue
}
// 豆包生图 4.0 是同步任务,不需要轮询
if job.ReqKey == DoubaoSeedream40ReqKey {
continue
}
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: job.ReqKey,
@@ -384,7 +291,7 @@ func (s *Service) pollTaskStatus() {
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
if resp.Code != 10000 {
if resp.Code != CodeSuccess {
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
continue
}