mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-10 19:54:25 +08:00
acommpelish jimeng AI refactor for PC
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
@@ -10,6 +11,9 @@ import (
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
@@ -94,13 +98,19 @@ func (c *Client) testConnection() error {
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 单独处理图片特效任务
|
||||
if req["req_key"] == ImageEffectReqKey {
|
||||
req["image_input1"] = req["image_urls"].([]any)[0]
|
||||
delete(req, "image_urls")
|
||||
}
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
|
||||
@@ -146,27 +156,29 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
// SubmitSyncImageTask 提交同步生图任务
|
||||
func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) {
|
||||
// 配置火山引擎访问密钥,目前只支持API Key验证
|
||||
client := arkruntime.NewClientWithApiKey(c.config.ApiKey)
|
||||
// 构造生图请求
|
||||
sequentialImageGeneration := model.SequentialImageGeneration("disabled")
|
||||
generateReq := model.GenerateImagesRequest{
|
||||
Model: req.ReqKey, // 模型名称
|
||||
Prompt: req.Prompt, // 提示词
|
||||
Size: volcengine.String(req.Size), // 图片尺寸
|
||||
SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成
|
||||
ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL
|
||||
Watermark: volcengine.Bool(false), // 不添加水印
|
||||
OptimizePrompt: volcengine.Bool(true), // 优化提示词
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
generateReq.Image = req.ImageUrls
|
||||
}
|
||||
// 调用生图 API
|
||||
resp, err := client.GenerateImages(context.Background(), generateReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -103,24 +103,18 @@ func (s *Service) processNextTask() {
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
func (s *Service) CreateTask(userId uint, req *types.JimengTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
Type: req.TaskType,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
Params: string(paramsJson),
|
||||
Params: utils.JsonEncode(req),
|
||||
Status: types.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -153,21 +147,61 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析任务参数
|
||||
var req types.JimengTaskRequest
|
||||
err := utils.JsonDecode(job.Params, &req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求并提交任务
|
||||
req, err := s.buildTaskRequest(&job)
|
||||
params, err := s.buildTaskRequest(&req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
logger.Debugf("提交即梦任务: %+v", params)
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
// 同步任务 ,后台执行
|
||||
if req.ReqKey == DoubaoSeedream40ReqKey {
|
||||
go func() {
|
||||
resp, err := s.client.SubmitSyncImageTask(req)
|
||||
if err != nil {
|
||||
_ = s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
return
|
||||
}
|
||||
logger.Infof("同步任务提交成功: %+v", resp)
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
updates := map[string]any{
|
||||
"raw_data": string(rawData),
|
||||
}
|
||||
if resp.Error != nil {
|
||||
updates["status"] = types.JMTaskStatusFailed
|
||||
updates["err_msg"] = resp.Error.Message
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
updates["status"] = types.JMTaskStatusSuccess
|
||||
// 下载图片
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(*resp.Data[0].Url, ".png", false)
|
||||
if err == nil {
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 异步任务 ,前台执行
|
||||
resp, err := s.client.SubmitTask(params)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
if resp.Code != CodeSuccess {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
@@ -185,168 +219,36 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||
// 解析任务参数
|
||||
func (s *Service) buildTaskRequest(req *types.JimengTaskRequest) (map[string]any, error) {
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.Params), ¶ms); err != nil {
|
||||
err := utils.JsonDecode(utils.JsonEncode(req), ¶ms)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建基础请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case types.JMTaskTypeImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case types.JMTaskTypeVideo:
|
||||
s.setImageToImageParams(req, params)
|
||||
case types.JMTaskTypeVirtualHuman:
|
||||
s.setImageEditParams(req, params)
|
||||
case types.JMTaskTypeActionTransfer:
|
||||
s.setImageEffectsParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// setTextToImageParams 设置文生图参数
|
||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setImageToImageParams 设置图生图参数
|
||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
s.setCommonParams(req, params) // 复用通用参数
|
||||
}
|
||||
|
||||
// setImageEditParams 设置图像编辑参数
|
||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
// 把 size 转成 width 和 height
|
||||
if size, ok := params["size"]; ok {
|
||||
if sizeStr, ok := size.(string); ok {
|
||||
if strings.Contains(sizeStr, "x") {
|
||||
sizes := strings.Split(sizeStr, "x")
|
||||
params["width"] = sizes[0]
|
||||
params["height"] = sizes[1]
|
||||
}
|
||||
}
|
||||
delete(params, "size")
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageEffectsParams 设置图像特效参数
|
||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
// duration 转成 frames
|
||||
if duration, ok := params["duration"]; ok {
|
||||
if secs, ok := duration.(int); ok {
|
||||
params["frames"] = secs*24 + 1
|
||||
}
|
||||
delete(params, "duration")
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTextToVideoParams 设置文生视频参数
|
||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageToVideoParams 设置图生视频参数
|
||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
// setCommonParams 设置通用参数(seed, width, height等)
|
||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
// 删除多余参数,剩下的就是各个任务自己专有参数了
|
||||
delete(params, "type")
|
||||
delete(params, "power")
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
@@ -368,6 +270,11 @@ func (s *Service) pollTaskStatus() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 豆包生图 4.0 是同步任务,不需要轮询
|
||||
if job.ReqKey == DoubaoSeedream40ReqKey {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
@@ -384,7 +291,7 @@ func (s *Service) pollTaskStatus() {
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
if resp.Code != CodeSuccess {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -4,32 +4,6 @@ import (
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
@@ -75,6 +49,8 @@ type QueryTaskResponse struct {
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
const CodeSuccess = 10000
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type types.JMTaskType `json:"type"`
|
||||
@@ -84,3 +60,8 @@ type CreateTaskRequest struct {
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
ImageEffectReqKey = "i2i_multi_style_zx2x"
|
||||
DoubaoSeedream40ReqKey = "doubao-seedream-4-0-250828"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user