mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-22 19:14:29 +08:00
即梦AI绘图功能前端页面完成
This commit is contained in:
@@ -8,11 +8,8 @@ import (
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"geekai/logger"
|
||||
)
|
||||
|
||||
var clientLogger = logger.GetLogger()
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
@@ -80,7 +77,7 @@ func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error)
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
@@ -105,7 +102,7 @@ func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
@@ -130,7 +127,7 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
clientLogger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
looger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
@@ -139,4 +136,4 @@ func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (c *Consumer) consume() {
|
||||
// processTask 处理任务
|
||||
func (c *Consumer) processTask() {
|
||||
// 从队列中获取任务
|
||||
var task map[string]interface{}
|
||||
var task map[string]any
|
||||
if err := c.service.taskQueue.LPop(&task); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"geekai/logger"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var serviceLogger = logger.GetLogger()
|
||||
var looger = logger2.GetLogger()
|
||||
|
||||
// Service 即梦服务
|
||||
type Service struct {
|
||||
@@ -29,8 +30,16 @@ type Service struct {
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, client *Client) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
@@ -99,7 +108,7 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return s.processTextToImage(&job)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return s.processImageToImagePortrait(&job)
|
||||
return s.processImageToImage(&job)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return s.processImageEdit(&job)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
@@ -171,15 +180,15 @@ func (s *Service) processTextToImage(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
}
|
||||
|
||||
// processImageToImagePortrait 处理图生图人像写真任务
|
||||
func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
|
||||
// processImageToImage 处理图生图任务
|
||||
func (s *Service) processImageToImage(job *model.JimengJob) error {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
@@ -249,7 +258,7 @@ func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -315,7 +324,7 @@ func (s *Service) processImageEdit(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -370,7 +379,7 @@ func (s *Service) processImageEffects(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -418,7 +427,7 @@ func (s *Service) processTextToVideo(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -482,7 +491,7 @@ func (s *Service) processImageToVideo(job *model.JimengJob) error {
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
looger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
@@ -505,7 +514,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
serviceLogger.Errorf("query jimeng task status failed: %v", err)
|
||||
looger.Errorf("query jimeng task status failed: %v", err)
|
||||
retryCount++
|
||||
continue
|
||||
}
|
||||
@@ -555,7 +564,7 @@ func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status))
|
||||
|
||||
default:
|
||||
serviceLogger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
looger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
retryCount++
|
||||
@@ -587,7 +596,7 @@ func (s *Service) UpdateJobProgress(jobId uint, progress int) error {
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
serviceLogger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
looger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
@@ -635,13 +644,12 @@ func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列
|
||||
func (s *Service) PushTaskToQueue(task map[string]interface{}) error {
|
||||
func (s *Service) PushTaskToQueue(task map[string]any) error {
|
||||
return s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// TestConnection 测试即梦AI连接
|
||||
func (s *Service) TestConnection(accessKey, secretKey string) error {
|
||||
// 创建临时客户端进行测试
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
@@ -655,13 +663,12 @@ func (s *Service) TestConnection(accessKey, secretKey string) error {
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if err.Error() == "unauthorized" || err.Error() == "access denied" {
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -671,9 +678,9 @@ func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.TestConnection(accessKey, secretKey)
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("新配置测试失败: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
@@ -709,18 +716,3 @@ func (s *Service) GetConfig() (*types.JimengConfig, error) {
|
||||
|
||||
return &jimengConfig, nil
|
||||
}
|
||||
|
||||
// LoadConfigFromDB 从数据库加载配置并更新客户端
|
||||
func (s *Service) LoadConfigFromDB() error {
|
||||
config, err := s.GetConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果配置中有AccessKey和SecretKey,则更新客户端
|
||||
if config.AccessKey != "" && config.SecretKey != "" {
|
||||
return s.UpdateClientConfig(config.AccessKey, config.SecretKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user