mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-10 19:54:25 +08:00
acommpelish jimeng AI refactor for PC
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
@@ -10,6 +11,9 @@ import (
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
@@ -94,13 +98,19 @@ func (c *Client) testConnection() error {
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
func (c *Client) SubmitTask(req map[string]any) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 单独处理图片特效任务
|
||||
if req["req_key"] == ImageEffectReqKey {
|
||||
req["image_input1"] = req["image_urls"].([]any)[0]
|
||||
delete(req, "image_urls")
|
||||
}
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
|
||||
@@ -146,27 +156,29 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
// SubmitSyncImageTask 提交同步生图任务
|
||||
func (c *Client) SubmitSyncImageTask(req types.JimengTaskRequest) (*model.ImagesResponse, error) {
|
||||
// 配置火山引擎访问密钥,目前只支持API Key验证
|
||||
client := arkruntime.NewClientWithApiKey(c.config.ApiKey)
|
||||
// 构造生图请求
|
||||
sequentialImageGeneration := model.SequentialImageGeneration("disabled")
|
||||
generateReq := model.GenerateImagesRequest{
|
||||
Model: req.ReqKey, // 模型名称
|
||||
Prompt: req.Prompt, // 提示词
|
||||
Size: volcengine.String(req.Size), // 图片尺寸
|
||||
SequentialImageGeneration: &sequentialImageGeneration, // 禁用序列生成
|
||||
ResponseFormat: volcengine.String(model.GenerateImagesResponseFormatURL), // 响应格式为 URL
|
||||
Watermark: volcengine.Bool(false), // 不添加水印
|
||||
OptimizePrompt: volcengine.Bool(true), // 优化提示词
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
generateReq.Image = req.ImageUrls
|
||||
}
|
||||
// 调用生图 API
|
||||
resp, err := client.GenerateImages(context.Background(), generateReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user