mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-23 19:44:29 +08:00
增加即梦AI功能页面
This commit is contained in:
@@ -49,7 +49,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.DallTask) {
|
||||
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push dall-e task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
|
||||
332
api/service/jimeng/client.go
Normal file
332
api/service/jimeng/client.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geekai/logger"
|
||||
)
|
||||
|
||||
var clientLogger = logger.GetLogger()
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
region string
|
||||
service string
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
return &Client{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
region: "cn-north-1",
|
||||
service: "cv",
|
||||
baseURL: "https://visual.volcengineapi.com",
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitTask 提交任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
// 构建请求URL
|
||||
queryParams := map[string]string{
|
||||
"Action": "CVSync2AsyncSubmitTask",
|
||||
"Version": "2022-08-31",
|
||||
}
|
||||
|
||||
reqURL := c.buildURL(queryParams)
|
||||
|
||||
// 序列化请求体
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create http request failed: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 签名请求
|
||||
if err := c.signRequest(httpReq, reqBody); err != nil {
|
||||
return nil, fmt.Errorf("sign request failed: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body failed: %w", err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// QueryTask 查询任务
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 构建请求URL
|
||||
queryParams := map[string]string{
|
||||
"Action": "CVSync2AsyncGetResult",
|
||||
"Version": "2022-08-31",
|
||||
}
|
||||
|
||||
reqURL := c.buildURL(queryParams)
|
||||
|
||||
// 序列化请求体
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create http request failed: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 签名请求
|
||||
if err := c.signRequest(httpReq, reqBody); err != nil {
|
||||
return nil, fmt.Errorf("sign request failed: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body failed: %w", err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 构建请求URL
|
||||
queryParams := map[string]string{
|
||||
"Action": "CVProcess",
|
||||
"Version": "2022-08-31",
|
||||
}
|
||||
|
||||
reqURL := c.buildURL(queryParams)
|
||||
|
||||
// 序列化请求体
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create http request failed: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 签名请求
|
||||
if err := c.signRequest(httpReq, reqBody); err != nil {
|
||||
return nil, fmt.Errorf("sign request failed: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body failed: %w", err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// buildURL 构建请求URL
|
||||
func (c *Client) buildURL(queryParams map[string]string) string {
|
||||
u, _ := url.Parse(c.baseURL)
|
||||
q := u.Query()
|
||||
for k, v := range queryParams {
|
||||
q.Set(k, v)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// signRequest 签名请求
|
||||
func (c *Client) signRequest(req *http.Request, body []byte) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 设置基本头部
|
||||
req.Header.Set("X-Date", now.Format("20060102T150405Z"))
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
|
||||
// 计算内容哈希
|
||||
contentHash := sha256.Sum256(body)
|
||||
req.Header.Set("X-Content-Sha256", hex.EncodeToString(contentHash[:]))
|
||||
|
||||
// 构建签名字符串
|
||||
canonicalRequest := c.buildCanonicalRequest(req)
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/request", now.Format("20060102"), c.region, c.service)
|
||||
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
||||
now.Format("20060102T150405Z"), credentialScope, sha256Hash(canonicalRequest))
|
||||
|
||||
// 计算签名
|
||||
signature := c.calculateSignature(stringToSign, now)
|
||||
|
||||
// 设置Authorization头部
|
||||
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
c.accessKey, credentialScope, c.getSignedHeaders(req), signature)
|
||||
req.Header.Set("Authorization", authorization)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildCanonicalRequest 构建规范请求
|
||||
func (c *Client) buildCanonicalRequest(req *http.Request) string {
|
||||
// HTTP方法
|
||||
method := req.Method
|
||||
|
||||
// 规范URI
|
||||
uri := req.URL.Path
|
||||
if uri == "" {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
// 规范查询字符串
|
||||
query := req.URL.Query()
|
||||
var queryParts []string
|
||||
for k, v := range query {
|
||||
for _, val := range v {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(val)))
|
||||
}
|
||||
}
|
||||
sort.Strings(queryParts)
|
||||
canonicalQuery := strings.Join(queryParts, "&")
|
||||
|
||||
// 规范头部
|
||||
var headerParts []string
|
||||
headers := make(map[string]string)
|
||||
for k, v := range req.Header {
|
||||
key := strings.ToLower(k)
|
||||
if len(v) > 0 {
|
||||
headers[key] = strings.TrimSpace(v[0])
|
||||
}
|
||||
}
|
||||
|
||||
var headerKeys []string
|
||||
for k := range headers {
|
||||
headerKeys = append(headerKeys, k)
|
||||
}
|
||||
sort.Strings(headerKeys)
|
||||
|
||||
for _, k := range headerKeys {
|
||||
headerParts = append(headerParts, fmt.Sprintf("%s:%s", k, headers[k]))
|
||||
}
|
||||
canonicalHeaders := strings.Join(headerParts, "\n") + "\n"
|
||||
|
||||
// 签名头部
|
||||
signedHeaders := c.getSignedHeaders(req)
|
||||
|
||||
// 载荷哈希
|
||||
payloadHash := req.Header.Get("X-Content-Sha256")
|
||||
|
||||
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
method, uri, canonicalQuery, canonicalHeaders, signedHeaders, payloadHash)
|
||||
}
|
||||
|
||||
// getSignedHeaders 获取签名头部
|
||||
func (c *Client) getSignedHeaders(req *http.Request) string {
|
||||
var headers []string
|
||||
for k := range req.Header {
|
||||
headers = append(headers, strings.ToLower(k))
|
||||
}
|
||||
sort.Strings(headers)
|
||||
return strings.Join(headers, ";")
|
||||
}
|
||||
|
||||
// calculateSignature 计算签名
|
||||
func (c *Client) calculateSignature(stringToSign string, t time.Time) string {
|
||||
kDate := hmacSha256([]byte("HMAC-SHA256"+c.secretKey), []byte(t.Format("20060102")))
|
||||
kRegion := hmacSha256(kDate, []byte(c.region))
|
||||
kService := hmacSha256(kRegion, []byte(c.service))
|
||||
kSigning := hmacSha256(kService, []byte("request"))
|
||||
signature := hmacSha256(kSigning, []byte(stringToSign))
|
||||
return hex.EncodeToString(signature)
|
||||
}
|
||||
|
||||
// hmacSha256 计算HMAC-SHA256
|
||||
func hmacSha256(key []byte, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// sha256Hash 计算SHA256哈希
|
||||
func sha256Hash(data string) string {
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
177
api/service/jimeng/consumer.go
Normal file
177
api/service/jimeng/consumer.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"geekai/logger"
|
||||
"geekai/store/model"
|
||||
)
|
||||
|
||||
var jimengLogger = logger.GetLogger()
|
||||
|
||||
// Consumer 即梦任务消费者
|
||||
type Consumer struct {
|
||||
service *Service
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewConsumer 创建即梦任务消费者
|
||||
func NewConsumer(service *Service) *Consumer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Consumer{
|
||||
service: service,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动消费者
|
||||
func (c *Consumer) Start() {
|
||||
jimengLogger.Info("Starting Jimeng task consumer...")
|
||||
go c.consume()
|
||||
}
|
||||
|
||||
// Stop 停止消费者
|
||||
func (c *Consumer) Stop() {
|
||||
jimengLogger.Info("Stopping Jimeng task consumer...")
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// consume 消费任务
|
||||
func (c *Consumer) consume() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
jimengLogger.Info("Jimeng task consumer stopped")
|
||||
return
|
||||
default:
|
||||
c.processTask()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask 处理任务
|
||||
func (c *Consumer) processTask() {
|
||||
// 从队列中获取任务
|
||||
var task map[string]interface{}
|
||||
if err := c.service.taskQueue.LPop(&task); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析任务
|
||||
jobIdFloat, ok := task["job_id"].(float64)
|
||||
if !ok {
|
||||
jimengLogger.Errorf("invalid job_id in task: %v", task)
|
||||
return
|
||||
}
|
||||
jobId := uint(jobIdFloat)
|
||||
|
||||
taskType, ok := task["type"].(string)
|
||||
if !ok {
|
||||
jimengLogger.Errorf("invalid task type in task: %v", task)
|
||||
return
|
||||
}
|
||||
|
||||
jimengLogger.Infof("Processing Jimeng task: job_id=%d, type=%s", jobId, taskType)
|
||||
|
||||
// 处理任务
|
||||
if err := c.service.ProcessTask(jobId); err != nil {
|
||||
jimengLogger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
|
||||
// 任务失败,直接标记为失败状态,不进行重试
|
||||
c.service.UpdateJobStatus(jobId, model.JimengJobStatusFailed, err.Error())
|
||||
} else {
|
||||
jimengLogger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// TaskQueueStatus 任务队列状态
|
||||
type TaskQueueStatus struct {
|
||||
QueueLength int `json:"queue_length"`
|
||||
ActiveTasks int `json:"active_tasks"`
|
||||
}
|
||||
|
||||
// GetQueueStatus 获取队列状态
|
||||
func (c *Consumer) GetQueueStatus() (*TaskQueueStatus, error) {
|
||||
// 获取队列长度
|
||||
length, err := c.service.taskQueue.Size()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取活跃任务数(正在处理的任务)
|
||||
activeTasks, err := c.service.GetPendingTaskCount(0) // 0表示所有用户
|
||||
if err != nil {
|
||||
activeTasks = 0
|
||||
}
|
||||
|
||||
return &TaskQueueStatus{
|
||||
QueueLength: int(length),
|
||||
ActiveTasks: int(activeTasks),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MonitorQueue 监控队列状态
|
||||
func (c *Consumer) MonitorQueue() {
|
||||
ticker := time.NewTicker(30 * time.Second) // 每30秒监控一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
status, err := c.GetQueueStatus()
|
||||
if err != nil {
|
||||
jimengLogger.Errorf("get queue status failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if status.QueueLength > 0 || status.ActiveTasks > 0 {
|
||||
jimengLogger.Infof("Jimeng queue status: queue_length=%d, active_tasks=%d",
|
||||
status.QueueLength, status.ActiveTasks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
func (c *Consumer) PushTaskToQueue(task map[string]interface{}) error {
|
||||
return c.service.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (c *Consumer) GetTaskStats() (map[string]interface{}, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := c.service.db.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"total": int64(0),
|
||||
"completed": int64(0),
|
||||
"processing": int64(0),
|
||||
"failed": int64(0),
|
||||
"pending": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
633
api/service/jimeng/service.go
Normal file
633
api/service/jimeng/service.go
Normal file
@@ -0,0 +1,633 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"geekai/logger"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var serviceLogger = logger.GetLogger()
|
||||
|
||||
// Service 即梦服务
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
taskQueue *store.RedisQueue
|
||||
client *Client
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, client *Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
taskQueue: taskQueue,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
TaskParams: string(paramsJson),
|
||||
Status: model.JimengJobStatusPending,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if err := s.db.Create(job).Error; err != nil {
|
||||
return nil, fmt.Errorf("create jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 推送到任务队列
|
||||
task := map[string]any{
|
||||
"job_id": job.Id,
|
||||
"type": job.Type,
|
||||
}
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ProcessTask 处理任务
|
||||
func (s *Service) ProcessTask(jobId uint) error {
|
||||
// 获取任务记录
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return fmt.Errorf("get jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JimengJobStatusProcessing, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 根据任务类型处理
|
||||
switch job.Type {
|
||||
case model.JimengJobTypeTextToImage:
|
||||
return s.processTextToImage(&job)
|
||||
case model.JimengJobTypeImageToImagePortrait:
|
||||
return s.processImageToImagePortrait(&job)
|
||||
case model.JimengJobTypeImageEdit:
|
||||
return s.processImageEdit(&job)
|
||||
case model.JimengJobTypeImageEffects:
|
||||
return s.processImageEffects(&job)
|
||||
case model.JimengJobTypeTextToVideo:
|
||||
return s.processTextToVideo(&job)
|
||||
case model.JimengJobTypeImageToVideo:
|
||||
return s.processImageToVideo(&job)
|
||||
default:
|
||||
return fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// processTextToImage 处理文生图任务
|
||||
func (s *Service) processTextToImage(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 设置参数
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageToImagePortrait 处理图生图人像写真任务
|
||||
func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 设置图像输入
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
|
||||
// 设置其他参数
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageEdit 处理图像编辑任务
|
||||
func (s *Service) processImageEdit(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 设置图像输入
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置其他参数
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageEffects 处理图像特效任务
|
||||
func (s *Service) processImageEffects(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
}
|
||||
|
||||
// 设置图像输入
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processTextToVideo 处理文生视频任务
|
||||
func (s *Service) processTextToVideo(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 设置参数
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageToVideo 处理图生视频任务
|
||||
func (s *Service) processImageToVideo(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 设置图像输入
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置其他参数
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
maxRetries := 60 // 最大重试次数,60次 * 5秒 = 5分钟
|
||||
retryCount := 0
|
||||
|
||||
for retryCount < maxRetries {
|
||||
time.Sleep(5 * time.Second) // 等待5秒
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: reqKey,
|
||||
TaskId: taskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
serviceLogger.Errorf("query jimeng task status failed: %v", err)
|
||||
retryCount++
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case TaskStatusDone:
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JimengJobStatusCompleted,
|
||||
"progress": 100,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
updates["img_url"] = resp.Data.ImageUrls[0]
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
updates["video_url"] = resp.Data.VideoUrl
|
||||
}
|
||||
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
|
||||
case TaskStatusInQueue:
|
||||
// 任务在队列中
|
||||
s.UpdateJobProgress(jobId, 10)
|
||||
|
||||
case TaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobProgress(jobId, 50)
|
||||
|
||||
case TaskStatusNotFound, TaskStatusExpired:
|
||||
// 任务未找到或已过期
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status))
|
||||
|
||||
default:
|
||||
serviceLogger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
retryCount++
|
||||
}
|
||||
|
||||
// 超时处理
|
||||
return s.handleTaskError(jobId, "task timeout")
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if errMsg != "" {
|
||||
updates["err_msg"] = errMsg
|
||||
}
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateJobProgress 更新任务进度
|
||||
func (s *Service) UpdateJobProgress(jobId uint, progress int) error {
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(map[string]any{
|
||||
"progress": progress,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
serviceLogger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JimengJobStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// GetJob 获取任务
|
||||
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// GetUserJobs 获取用户任务列表
|
||||
func (s *Service) GetUserJobs(userId uint, page, pageSize int) ([]*model.JimengJob, int64, error) {
|
||||
var jobs []*model.JimengJob
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||
|
||||
// 统计总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return jobs, total, nil
|
||||
}
|
||||
|
||||
// GetPendingTaskCount 获取用户未完成任务数量
|
||||
func (s *Service) GetPendingTaskCount(userId uint) (int64, error) {
|
||||
var count int64
|
||||
err := s.db.Model(&model.JimengJob{}).Where("user_id = ? AND status IN (?)", userId,
|
||||
[]string{model.JimengJobStatusPending, model.JimengJobStatusProcessing}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteJob 删除任务
|
||||
func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
||||
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列
|
||||
func (s *Service) PushTaskToQueue(task map[string]interface{}) error {
|
||||
return s.taskQueue.RPush(task)
|
||||
}
|
||||
163
api/service/jimeng/types.go
Normal file
163
api/service/jimeng/types.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package jimeng
|
||||
|
||||
import "time"
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// QueryTaskRequest 查询任务请求
|
||||
type QueryTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
TaskId string `json:"task_id"`
|
||||
ReqJson string `json:"req_json,omitempty"`
|
||||
}
|
||||
|
||||
// QueryTaskResponse 查询任务响应
|
||||
type QueryTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
AlgorithmBaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMessage string `json:"status_message"`
|
||||
} `json:"algorithm_base_resp"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status string `json:"status"`
|
||||
LlmResult string `json:"llm_result"`
|
||||
PeResult string `json:"pe_result"`
|
||||
PredictTagsResult string `json:"predict_tags_result"`
|
||||
RephraserResult string `json:"rephraser_result"`
|
||||
VlmResult string `json:"vlm_result"`
|
||||
InferCtx interface{} `json:"infer_ctx"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// TaskStatus 任务状态
|
||||
const (
|
||||
TaskStatusInQueue = "in_queue" // 任务已提交
|
||||
TaskStatusGenerating = "generating" // 任务处理中
|
||||
TaskStatusDone = "done" // 处理完成
|
||||
TaskStatusNotFound = "not_found" // 任务未找到
|
||||
TaskStatusExpired = "expired" // 任务已过期
|
||||
)
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type string `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
ReqKey string `json:"req_key"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
// TaskInfo 任务信息
|
||||
type TaskInfo struct {
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Type string `json:"type"`
|
||||
ReqKey string `json:"req_key"`
|
||||
Prompt string `json:"prompt"`
|
||||
TaskParams string `json:"task_params"`
|
||||
ImgURL string `json:"img_url"`
|
||||
VideoURL string `json:"video_url"`
|
||||
Progress int `json:"progress"`
|
||||
Status string `json:"status"`
|
||||
ErrMsg string `json:"err_msg"`
|
||||
Power int `json:"power"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// LogoInfo 水印信息
|
||||
type LogoInfo struct {
|
||||
AddLogo bool `json:"add_logo"`
|
||||
Position int `json:"position"`
|
||||
Language int `json:"language"`
|
||||
Opacity float64 `json:"opacity"`
|
||||
LogoTextContent string `json:"logo_text_content"`
|
||||
}
|
||||
|
||||
// ReqJsonConfig 查询配置
|
||||
type ReqJsonConfig struct {
|
||||
ReturnUrl bool `json:"return_url"`
|
||||
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
|
||||
}
|
||||
|
||||
// ImageEffectTemplate 图像特效模板
|
||||
const (
|
||||
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
|
||||
TemplateIdMyWorld = "my_world" // 像素世界风
|
||||
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
|
||||
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
|
||||
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
|
||||
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
|
||||
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
|
||||
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
|
||||
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
|
||||
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
|
||||
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
|
||||
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
|
||||
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
|
||||
TemplateIdGlassBall = "glass_ball" // 玻璃球
|
||||
)
|
||||
|
||||
// AspectRatio 视频宽高比
|
||||
const (
|
||||
AspectRatio16_9 = "16:9" // 1280×720
|
||||
AspectRatio9_16 = "9:16" // 720×1280
|
||||
AspectRatio1_1 = "1:1" // 960×960
|
||||
AspectRatio4_3 = "4:3" // 960×720
|
||||
AspectRatio3_4 = "3:4" // 720×960
|
||||
AspectRatio21_9 = "21:9" // 1680×720
|
||||
AspectRatio9_21 = "9:21" // 720×1680
|
||||
)
|
||||
|
||||
// GenMode 生成模式
|
||||
const (
|
||||
GenModeCreative = "creative" // 提示词模式
|
||||
GenModeReference = "reference" // 全参考模式
|
||||
GenModeReferenceChar = "reference_char" // 人物参考模式
|
||||
)
|
||||
@@ -212,7 +212,9 @@ func (s *Service) DownloadImages() {
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push mj task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
|
||||
@@ -253,7 +253,9 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, err
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push sd task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
|
||||
@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
|
||||
func (s *Service) PushTask(task types.SunoTask) {
|
||||
logger.Infof("add a new Suno task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push suno task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
|
||||
@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
|
||||
func (s *Service) PushTask(task types.VideoTask) {
|
||||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push video task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
|
||||
Reference in New Issue
Block a user