AI3D 功能完成

This commit is contained in:
GeekMaster
2025-09-04 18:36:49 +08:00
parent 53866d1461
commit 52d297624d
30 changed files with 829 additions and 969 deletions

View File

@@ -1,13 +1,18 @@
package ai3d
import (
"encoding/json"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/go-redis/redis/v8"
@@ -22,52 +27,81 @@ type Service struct {
taskQueue *store.RedisQueue
tencentClient *Tencent3DClient
giteeClient *Gitee3DClient
userService *service.UserService
uploadManager *oss.UploaderManager
}
// NewService 创建3D生成服务
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient, userService *service.UserService, uploadManager *oss.UploaderManager) *Service {
return &Service{
db: db,
taskQueue: store.NewRedisQueue("3D_Task_Queue", redisCli),
tencentClient: tencentClient,
giteeClient: giteeClient,
userService: userService,
uploadManager: uploadManager,
}
}
// CreateJob 创建3D生成任务
func (s *Service) CreateJob(userId uint, request vo.AI3DJobCreate) (*model.AI3DJob, error) {
// 创建任务记录
job := &model.AI3DJob{
UserId: userId,
Type: request.Type,
Power: request.Power,
Model: request.Model,
Status: types.AI3DJobStatusPending,
func (s *Service) CreateJob(userId uint, request vo.AI3DJobParams) (*model.AI3DJob, error) {
switch request.Type {
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return nil, fmt.Errorf("模力方舟 3D 服务未初始化")
}
if !s.giteeClient.GetConfig().Enabled {
return nil, fmt.Errorf("模力方舟 3D 服务未启用")
}
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return nil, fmt.Errorf("腾讯云 3D 服务未初始化")
}
if !s.tencentClient.GetConfig().Enabled {
return nil, fmt.Errorf("腾讯云 3D 服务未启用")
}
default:
return nil, fmt.Errorf("不支持的 3D 服务类型: %s", request.Type)
}
// 序列化参数
params := map[string]any{
"prompt": request.Prompt,
"image_url": request.ImageURL,
"model": request.Model,
"power": request.Power,
// 创建任务记录
job := &model.AI3DJob{
UserId: userId,
Type: request.Type,
Power: request.Power,
Model: request.Model,
Status: types.AI3DJobStatusPending,
PreviewURL: request.ImageURL,
}
paramsJSON, _ := json.Marshal(params)
job.Params = string(paramsJSON)
job.Params = utils.JsonEncode(request)
// 保存到数据库
if err := s.db.Create(job).Error; err != nil {
return nil, fmt.Errorf("failed to create 3D job: %v", err)
}
// 更新用户算力
err := s.userService.DecreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: job.Model,
Remark: fmt.Sprintf("创建3D任务消耗%d算力", job.Power),
})
if err != nil {
return nil, fmt.Errorf("failed to update user power: %v", err)
}
// 将任务添加到队列
s.PushTask(job)
request.JobId = job.Id
s.PushTask(request)
return job, nil
}
// PushTask 将任务添加到队列
func (s *Service) PushTask(job *model.AI3DJob) {
func (s *Service) PushTask(job vo.AI3DJobParams) {
logger.Infof("add a new 3D task to the queue: %+v", job)
if err := s.taskQueue.RPush(job); err != nil {
logger.Errorf("push 3D task to queue failed: %v", err)
@@ -76,72 +110,70 @@ func (s *Service) PushTask(job *model.AI3DJob) {
// Run 启动任务处理器
func (s *Service) Run() {
// 将数据库中未完成的任务加载到队列
var jobs []model.AI3DJob
s.db.Where("status IN ?", []string{types.AI3DJobStatusPending, types.AI3DJobStatusProcessing}).Find(&jobs)
for _, job := range jobs {
s.PushTask(&job)
}
logger.Info("Starting 3D job consumer...")
go func() {
for {
var job model.AI3DJob
err := s.taskQueue.LPop(&job)
var params vo.AI3DJobParams
err := s.taskQueue.LPop(&params)
if err != nil {
logger.Errorf("taking 3D task with error: %v", err)
continue
}
logger.Infof("handle a new 3D task: %+v", job)
logger.Infof("handle a new 3D task: %+v", params)
go func() {
if err := s.processJob(&job); err != nil {
if err := s.processJob(&params); err != nil {
logger.Errorf("error processing 3D job: %v", err)
s.updateJobStatus(&job, types.AI3DJobStatusFailed, 0, err.Error())
s.updateJobStatus(params.JobId, types.AI3DJobStatusFailed, err.Error())
}
}()
}
}()
go s.pollJobStatus()
}
// processJob 处理3D任务
func (s *Service) processJob(job *model.AI3DJob) error {
func (s *Service) processJob(params *vo.AI3DJobParams) error {
// 更新状态为处理中
s.updateJobStatus(job, types.AI3DJobStatusProcessing, 10, "")
// 解析参数
var params map[string]any
if err := json.Unmarshal([]byte(job.Params), &params); err != nil {
return fmt.Errorf("failed to parse job params: %v", err)
}
s.updateJobStatus(params.JobId, types.AI3DJobStatusProcessing, "")
var taskId string
var err error
// 根据类型选择客户端
switch job.Type {
case "tencent":
switch params.Type {
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return fmt.Errorf("tencent 3D client not initialized")
}
tencentParams := Tencent3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
EnablePBR: false,
Prompt: params.Prompt,
ImageURL: params.ImageURL,
ResultFormat: params.FileFormat,
EnablePBR: params.EnablePBR,
}
taskId, err = s.tencentClient.SubmitJob(tencentParams)
case "gitee":
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return fmt.Errorf("gitee 3D client not initialized")
}
giteeParams := Gitee3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
Model: params.Model,
Texture: params.Texture,
Seed: params.Seed,
NumInferenceSteps: params.NumInferenceSteps,
GuidanceScale: params.GuidanceScale,
OctreeResolution: params.OctreeResolution,
ImageURL: params.ImageURL,
}
if params.Model == "Hunyuan3D-2" {
giteeParams.Type = strings.ToLower(params.FileFormat)
} else {
giteeParams.FileFormat = strings.ToLower(params.FileFormat)
}
taskId, err = s.giteeClient.SubmitJob(giteeParams)
default:
return fmt.Errorf("unsupported 3D API type: %s", job.Type)
return fmt.Errorf("unsupported 3D API type: %s", params.Type)
}
if err != nil {
@@ -149,43 +181,65 @@ func (s *Service) processJob(job *model.AI3DJob) error {
}
// 更新任务ID
job.TaskId = taskId
s.db.Model(job).Update("task_id", taskId)
// 开始轮询任务状态
go s.pollJobStatus(job)
s.db.Model(model.AI3DJob{}).Where("id = ?", params.JobId).Update("task_id", taskId)
return nil
}
// pollJobStatus 轮询任务状态
func (s *Service) pollJobStatus(job *model.AI3DJob) {
func (s *Service) pollJobStatus() {
// 10秒轮询一次
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
result, err := s.queryJobStatus(job)
for range ticker.C {
var jobs []model.AI3DJob
s.db.Where("status IN (?)", []string{types.AI3DJobStatusProcessing, types.AI3DJobStatusPending}).Find(&jobs)
if len(jobs) == 0 {
logger.Debug("no 3D jobs to poll, sleep 10s")
continue
}
for _, job := range jobs {
// 15 分钟超时
if job.CreatedAt.Before(time.Now().Add(-20 * time.Minute)) {
s.updateJobStatus(job.Id, types.AI3DJobStatusFailed, "task timeout")
continue
}
result, err := s.queryJobStatus(&job)
if err != nil {
logger.Errorf("failed to query job status: %v", err)
continue
}
// 更新进度
s.updateJobStatus(job, result.Status, result.Progress, result.ErrorMsg)
// 如果任务完成或失败,停止轮询
if result.Status == types.AI3DJobStatusCompleted || result.Status == types.AI3DJobStatusFailed {
if result.Status == types.AI3DJobStatusCompleted {
// 更新结果文件URL
s.db.Model(job).Updates(map[string]interface{}{
"img_url": result.FileURL,
"preview_url": result.PreviewURL,
})
}
return
updates := map[string]any{
"status": result.Status,
"raw_data": result.RawData,
"err_msg": result.ErrorMsg,
}
if result.FileURL != "" {
// 下载文件到本地
url, err := s.uploadManager.GetUploadHandler().PutUrlFile(result.FileURL, getFileExt(result.FileURL), false)
if err != nil {
logger.Errorf("failed to download file: %v", err)
continue
}
updates["file_url"] = url
logger.Infof("download file: %s", url)
}
if result.PreviewURL != "" {
url, err := s.uploadManager.GetUploadHandler().PutUrlFile(result.PreviewURL, getFileExt(result.PreviewURL), false)
if err != nil {
logger.Errorf("failed to download preview image: %v", err)
continue
}
updates["preview_url"] = url
logger.Infof("download preview image: %s", url)
}
s.db.Model(&model.AI3DJob{}).Where("id = ?", job.Id).Updates(updates)
}
}
}
@@ -193,12 +247,12 @@ func (s *Service) pollJobStatus(job *model.AI3DJob) {
// queryJobStatus 查询任务状态
func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, error) {
switch job.Type {
case "tencent":
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return nil, fmt.Errorf("tencent 3D client not initialized")
}
return s.tencentClient.QueryJob(job.TaskId)
case "gitee":
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return nil, fmt.Errorf("gitee 3D client not initialized")
}
@@ -209,19 +263,12 @@ func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, erro
}
// updateJobStatus 更新任务状态
func (s *Service) updateJobStatus(job *model.AI3DJob, status string, progress int, errMsg string) {
updates := map[string]interface{}{
"status": status,
"progress": progress,
"updated_at": time.Now(),
}
if errMsg != "" {
updates["err_msg"] = errMsg
}
func (s *Service) updateJobStatus(jobId uint, status string, errMsg string) error {
if err := s.db.Model(job).Updates(updates).Error; err != nil {
logger.Errorf("failed to update job status: %v", err)
}
return s.db.Model(model.AI3DJob{}).Where("id = ?", jobId).Updates(map[string]any{
"status": status,
"err_msg": errMsg,
}).Error
}
// GetJobList 获取任务列表
@@ -254,10 +301,10 @@ func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error)
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
_ = utils.JsonDecode(job.Params, &jobVO.Params)
jobList = append(jobList, jobVO)
}
@@ -269,29 +316,34 @@ func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error)
}, nil
}
// GetJobById 根据ID获取任务
func (s *Service) GetJobById(id uint) (*model.AI3DJob, error) {
var job model.AI3DJob
if err := s.db.Where("id = ?", id).First(&job).Error; err != nil {
return nil, err
}
return &job, nil
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(id uint, userId uint) error {
func (s *Service) DeleteUserJob(id uint, userId uint) error {
var job model.AI3DJob
if err := s.db.Where("id = ? AND user_id = ?", id, userId).First(&job).Error; err != nil {
err := s.db.Where("id = ?", id).Where("user_id = ?", userId).First(&job).Error
if err != nil {
return err
}
// 如果任务已完成,退还算力
if job.Status == types.AI3DJobStatusCompleted {
// TODO: 实现算力退还逻辑
logger2.GetLogger().Infof("should refund power %d for user %d", job.Power, userId)
tx := s.db.Begin()
err = tx.Delete(&job).Error
if err != nil {
return err
}
return s.db.Delete(&job).Error
// 失败的任务要退回算力
if job.Status == types.AI3DJobStatusFailed {
err = s.userService.IncreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.Model,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
tx.Rollback()
return err
}
}
tx.Commit()
return nil
}
// GetSupportedModels 获取支持的模型列表
@@ -316,12 +368,15 @@ func (s *Service) UpdateConfig(config types.AI3DConfig) {
}
}
// getString 从map中获取字符串值
func (s *Service) getString(params map[string]interface{}, key string) string {
if val, ok := params[key]; ok {
if str, ok := val.(string); ok {
return str
}
// getFileExt 获取文件扩展名
func getFileExt(fileURL string) string {
parse, err := url.Parse(fileURL)
if err != nil {
return ""
}
return ""
ext := filepath.Ext(parse.Path)
if ext == "" {
return ".glb"
}
return ext
}