AI3D 功能完成

This commit is contained in:
GeekMaster
2025-09-04 18:36:49 +08:00
parent 53866d1461
commit 52d297624d
30 changed files with 829 additions and 969 deletions

View File

@@ -31,12 +31,12 @@ const (
// AI3DJobResult 3D任务结果
type AI3DJobResult struct {
JobId string `json:"job_id"` // 任务ID
TaskId string `json:"task_id"` // 任务ID
Status string `json:"status"` // 任务状态
Progress int `json:"progress"` // 任务进度 (0-100)
FileURL string `json:"file_url"` // 3D模型文件URL
PreviewURL string `json:"preview_url"` // 预览图片URL
ErrorMsg string `json:"error_msg"` // 错误信息
RawData string `json:"raw_data"` // 原始数据
}
// AI3DModel 3D模型配置
@@ -60,6 +60,6 @@ type AI3DJobRequest struct {
const (
AI3DJobStatusPending = "pending" // 等待中
AI3DJobStatusProcessing = "processing" // 处理中
AI3DJobStatusCompleted = "completed" // 已完成
AI3DJobStatusSuccess = "success" // 已完成
AI3DJobStatusFailed = "failed" // 失败
)

View File

@@ -96,6 +96,9 @@ func (h *AI3DHandler) GetJobList(c *gin.Context) {
if err != nil {
continue
}
utils.JsonDecode(job.Params, &jobVo.Params)
jobVo.CreatedAt = job.CreatedAt.Unix()
jobVo.UpdatedAt = job.UpdatedAt.Unix()
jobList = append(jobList, jobVo)
}
@@ -128,6 +131,9 @@ func (h *AI3DHandler) GetJobDetail(c *gin.Context) {
resp.ERROR(c, "获取任务详情失败")
return
}
utils.JsonDecode(job.Params, &jobVo.Params)
jobVo.CreatedAt = job.CreatedAt.Unix()
jobVo.UpdatedAt = job.UpdatedAt.Unix()
resp.SUCCESS(c, jobVo)
}
@@ -167,14 +173,14 @@ func (h *AI3DHandler) GetStats(c *gin.Context) {
var stats struct {
Pending int64 `json:"pending"`
Processing int64 `json:"processing"`
Completed int64 `json:"completed"`
Success int64 `json:"success"`
Failed int64 `json:"failed"`
}
// 统计各状态的任务数量
h.db.Model(&model.AI3DJob{}).Where("status = ?", "pending").Count(&stats.Pending)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "processing").Count(&stats.Processing)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "completed").Count(&stats.Completed)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "success").Count(&stats.Success)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "failed").Count(&stats.Failed)
resp.SUCCESS(c, stats)

View File

@@ -5,7 +5,6 @@ import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/ai3d"
"geekai/store/model"
"geekai/store/vo"
@@ -19,14 +18,12 @@ import (
type AI3DHandler struct {
BaseHandler
service *ai3d.Service
userService *service.UserService
service *ai3d.Service
}
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service, userService *service.UserService) *AI3DHandler {
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service) *AI3DHandler {
return &AI3DHandler{
service: service,
userService: userService,
service: service,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -47,30 +44,14 @@ func (h *AI3DHandler) RegisterRoutes() {
group.POST("generate", h.Generate)
group.GET("jobs", h.JobList)
group.GET("jobs/mock", h.ListMock) // 演示数据接口
group.GET("job/:id", h.JobDetail)
group.GET("job/delete", h.DeleteJob)
group.GET("download/:id", h.Download)
}
}
// Generate 创建3D生成任务
func (h *AI3DHandler) Generate(c *gin.Context) {
var request struct {
// 通用参数
Type types.AI3DTaskType `json:"type" binding:"required"` // API类型 (tencent/gitee)
Model string `json:"model" binding:"required"` // 3D模型类型
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
FileFormat string `json:"file_format"` // 输出文件格式
// 腾讯3d专有参数
EnablePBR bool `json:"enable_pbr"` // 是否开启PBR材质
// Gitee3d专有参数
Texture bool `json:"texture"` // 是否开启纹理
Seed int `json:"seed"` // 随机种子
NumInferenceSteps int `json:"num_inference_steps"` //迭代次数
GuidanceScale float64 `json:"guidance_scale"` //引导系数
OctreeResolution int `json:"octree_resolution"` // 3D 渲染精度越高3D 细节越丰富
}
var request vo.AI3DJobParams
if err := c.ShouldBindJSON(&request); err != nil {
resp.ERROR(c, "参数错误")
return
@@ -90,17 +71,17 @@ func (h *AI3DHandler) Generate(c *gin.Context) {
logger.Infof("request: %+v", request)
// // 获取用户ID
// userId := h.GetLoginUserId(c)
// // 创建任务
// job, err := h.service.CreateJob(uint(userId), request)
// if err != nil {
// resp.ERROR(c, fmt.Sprintf("创建任务失败: %v", err))
// return
// }
// 获取用户ID
userId := h.GetLoginUserId(c)
// 创建任务
job, err := h.service.CreateJob(uint(userId), request)
if err != nil {
resp.ERROR(c, fmt.Sprintf("创建任务失败: %v", err))
return
}
resp.SUCCESS(c, gin.H{
"job_id": 0,
"job_id": job.Id,
"message": "任务创建成功",
})
}
@@ -132,133 +113,24 @@ func (h *AI3DHandler) JobList(c *gin.Context) {
resp.SUCCESS(c, jobList)
}
// JobDetail 获取任务详情
func (h *AI3DHandler) JobDetail(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 转换为VO
jobVO := vo.AI3DJob{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Power: job.Power,
TaskId: job.TaskId,
FileURL: job.FileURL,
PreviewURL: job.PreviewURL,
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
resp.SUCCESS(c, jobVO)
}
// DeleteJob 删除任务
func (h *AI3DHandler) DeleteJob(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := c.Query("id")
if id == "" {
id := h.GetInt(c, "id", 0)
if id == 0 {
resp.ERROR(c, "任务ID不能为空")
return
}
var job model.AI3DJob
err := h.DB.Where("id = ?", id).Where("user_id = ?", userId).First(&job).Error
err := h.service.DeleteUserJob(uint(id), uint(userId))
if err != nil {
resp.ERROR(c, err.Error())
resp.ERROR(c, "删除任务失败")
return
}
err = h.DB.Delete(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 失败的任务要退回算力
if job.Status == types.AI3DJobStatusFailed {
err = h.userService.IncreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.Model,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, gin.H{"message": "删除成功"})
}
// Download 下载3D模型
func (h *AI3DHandler) Download(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 检查任务状态
if job.Status != types.AI3DJobStatusCompleted {
resp.ERROR(c, "任务尚未完成")
return
}
if job.FileURL == "" {
resp.ERROR(c, "模型文件不存在")
return
}
// 重定向到下载链接
c.Redirect(302, job.FileURL)
}
// GetConfigs 获取3D生成配置
func (h *AI3DHandler) GetConfigs(c *gin.Context) {
var config model.Config
@@ -281,8 +153,6 @@ func (h *AI3DHandler) GetConfigs(c *gin.Context) {
config3d.Tencent.Models = models["tencent"]
}
logger.Info("config3d: ", config3d)
resp.SUCCESS(c, config3d)
}
@@ -299,9 +169,9 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
FileURL: "https://img.r9it.com/R03TQZ7PZ386RGL7PTMNGFOHAJW15WYF.glb",
PreviewURL: "/static/upload/2025/9/1756873317505073.png",
Model: "gitee-3d-v1",
Status: types.AI3DJobStatusCompleted,
Status: types.AI3DJobStatusSuccess,
ErrMsg: "",
Params: `{"prompt":"一只可爱的小猫","image_url":"","texture":true,"seed":42}`,
Params: vo.AI3DJobParams{Prompt: "一只可爱的小猫", ImageURL: "", Texture: true, Seed: 42},
CreatedAt: 1704067200, // 2024-01-01 00:00:00
UpdatedAt: 1704067800, // 2024-01-01 00:10:00
},
@@ -316,7 +186,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v2",
Status: types.AI3DJobStatusProcessing,
ErrMsg: "",
Params: `{"prompt":"一个现代建筑模型","image_url":"","enable_pbr":true}`,
Params: vo.AI3DJobParams{Prompt: "一个现代建筑模型", ImageURL: "", EnablePBR: true},
CreatedAt: 1704070800, // 2024-01-01 01:00:00
UpdatedAt: 1704070800, // 2024-01-01 01:00:00
},
@@ -331,7 +201,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "gitee-3d-v1",
Status: types.AI3DJobStatusPending,
ErrMsg: "",
Params: `{"prompt":"一辆跑车模型","image_url":"https://example.com/car.jpg","texture":false}`,
Params: vo.AI3DJobParams{Prompt: "一辆跑车模型", ImageURL: "https://example.com/car.jpg", Texture: false},
CreatedAt: 1704074400, // 2024-01-01 02:00:00
UpdatedAt: 1704074400, // 2024-01-01 02:00:00
},
@@ -346,7 +216,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v1",
Status: types.AI3DJobStatusFailed,
ErrMsg: "模型生成失败:输入图片质量不符合要求",
Params: `{"prompt":"一个机器人模型","image_url":"https://example.com/robot.jpg","enable_pbr":false}`,
Params: vo.AI3DJobParams{Prompt: "一个机器人模型", ImageURL: "https://example.com/robot.jpg", EnablePBR: false},
CreatedAt: 1704078000, // 2024-01-01 03:00:00
UpdatedAt: 1704078600, // 2024-01-01 03:10:00
},
@@ -359,9 +229,9 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
FileURL: "https://ai.gitee.com/a8c1af8e-26e9-4ca6-aa5c-6d4ba86bfdac",
PreviewURL: "https://ai.gitee.com/a8c1af8e-26e9-4ca6-aa5c-6d4ba86bfdac",
Model: "gitee-3d-v2",
Status: types.AI3DJobStatusCompleted,
Status: types.AI3DJobStatusSuccess,
ErrMsg: "",
Params: `{"prompt":"一个复杂的机械装置","image_url":"","texture":true,"octree_resolution":512}`,
Params: vo.AI3DJobParams{Prompt: "一个复杂的机械装置", ImageURL: "", Texture: true, OctreeResolution: 512},
CreatedAt: 1704081600, // 2024-01-01 04:00:00
UpdatedAt: 1704082200, // 2024-01-01 04:10:00
},
@@ -376,17 +246,17 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v2",
Status: types.AI3DJobStatusProcessing,
ErrMsg: "",
Params: `{"prompt":"一个科幻飞船","image_url":"","enable_pbr":true}`,
Params: vo.AI3DJobParams{Prompt: "一个科幻飞船", ImageURL: "", EnablePBR: true},
CreatedAt: 1704085200, // 2024-01-01 05:00:00
UpdatedAt: 1704085200, // 2024-01-01 05:00:00
},
}
// 创建分页响应
mockResponse := vo.ThreeDJobList{
mockResponse := vo.Page{
Page: 1,
PageSize: 10,
Total: len(mockJobs),
Total: int64(len(mockJobs)),
Items: mockJobs,
}

View File

@@ -16,36 +16,43 @@ type Gitee3DClient struct {
}
type Gitee3DParams struct {
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
ResultFormat string `json:"result_format"` // 输出格式
Model string `json:"model"` // 模型名称
FileFormat string `json:"file_format,omitempty"` // 文件格式(Step1X-3D、Hi3DGen模型适用),支持 glb 和 stl
Type string `json:"type,omitempty"` // 输出格式(Hunyuan3D-2模型适用)
ImageURL string `json:"image_url"` // 输入图片URL
Texture bool `json:"texture,omitempty"` // 是否开启纹理
Seed int `json:"seed,omitempty"` // 随机种子
NumInferenceSteps int `json:"num_inference_steps,omitempty"` //迭代次数
GuidanceScale float64 `json:"guidance_scale,omitempty"` //引导系数
OctreeResolution int `json:"octree_resolution,omitempty"` // 3D 渲染精度越高3D 细节越丰富
}
type Gitee3DResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
TaskID string `json:"task_id"`
Output struct {
FileURL string `json:"file_url,omitempty"`
PreviewURL string `json:"preview_url,omitempty"`
} `json:"output"`
Status string `json:"status"`
CreatedAt any `json:"created_at"`
StartedAt any `json:"started_at"`
CompletedAt any `json:"completed_at"`
Urls struct {
Get string `json:"get"`
Cancel string `json:"cancel"`
} `json:"urls"`
}
type Gitee3DQueryResponse struct {
Code int `json:"code"`
type GiteeErrorResponse struct {
Error int `json:"error"`
Message string `json:"message"`
Data struct {
Status string `json:"status"`
Progress int `json:"progress"`
ResultURL string `json:"result_url"`
PreviewURL string `json:"preview_url"`
ErrorMsg string `json:"error_msg"`
} `json:"data"`
}
func NewGitee3DClient(sysConfig *types.SystemConfig) *Gitee3DClient {
return &Gitee3DClient{
httpClient: req.C().SetTimeout(time.Minute * 3),
config: sysConfig.AI3D.Gitee,
apiURL: "https://ai.gitee.com/v1/async/image-to-3d",
apiURL: "https://ai.gitee.com/v1",
}
}
@@ -53,73 +60,62 @@ func (c *Gitee3DClient) UpdateConfig(config types.Gitee3DConfig) {
c.config = config
}
func (c *Gitee3DClient) GetConfig() *types.Gitee3DConfig {
return &c.config
}
// SubmitJob 提交3D生成任务
func (c *Gitee3DClient) SubmitJob(params Gitee3DParams) (string, error) {
requestBody := map[string]any{
"prompt": params.Prompt,
"image_url": params.ImageURL,
"result_format": params.ResultFormat,
}
var giteeResp Gitee3DResponse
response, err := c.httpClient.R().
SetHeader("Authorization", "Bearer "+c.config.APIKey).
SetHeader("Content-Type", "application/json").
SetBody(requestBody).
SetBody(params).
SetSuccessResult(&giteeResp).
Post(c.apiURL + "/async/image-to-3d")
if err != nil {
return "", fmt.Errorf("failed to submit gitee 3D job: %v", err)
}
var giteeResp Gitee3DResponse
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
return "", fmt.Errorf("failed to parse gitee response: %v", err)
if giteeResp.TaskID == "" {
var giteeErr GiteeErrorResponse
_ = json.Unmarshal(response.Bytes(), &giteeErr)
return "", fmt.Errorf("no task ID returned from gitee 3D API: %s", giteeErr.Message)
}
if giteeResp.Code != 0 {
return "", fmt.Errorf("gitee API error: %s", giteeResp.Message)
}
if giteeResp.Data.TaskID == "" {
return "", fmt.Errorf("no task ID returned from gitee 3D API")
}
return giteeResp.Data.TaskID, nil
return giteeResp.TaskID, nil
}
// QueryJob 查询任务状态
func (c *Gitee3DClient) QueryJob(taskId string) (*types.AI3DJobResult, error) {
var giteeResp Gitee3DResponse
apiURL := fmt.Sprintf("%s/task/%s", c.apiURL, taskId)
response, err := c.httpClient.R().
SetHeader("Authorization", "Bearer "+c.config.APIKey).
Get(fmt.Sprintf("%s/task/%s/get", c.apiURL, taskId))
SetSuccessResult(&giteeResp).
Get(apiURL)
if err != nil {
return nil, fmt.Errorf("failed to query gitee 3D job: %v", err)
}
var giteeResp Gitee3DQueryResponse
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
return nil, fmt.Errorf("failed to parse gitee query response: %v", err)
}
if giteeResp.Code != 0 {
return nil, fmt.Errorf("gitee API error: %s", giteeResp.Message)
}
result := &types.AI3DJobResult{
JobId: taskId,
Status: c.convertStatus(giteeResp.Data.Status),
Progress: giteeResp.Data.Progress,
TaskId: taskId,
Status: c.convertStatus(giteeResp.Status),
}
// 根据状态设置结果
switch giteeResp.Data.Status {
case "completed":
result.FileURL = giteeResp.Data.ResultURL
result.PreviewURL = giteeResp.Data.PreviewURL
case "failed":
result.ErrorMsg = giteeResp.Data.ErrorMsg
if giteeResp.TaskID == "" {
var giteeErr GiteeErrorResponse
_ = json.Unmarshal(response.Bytes(), &giteeErr)
result.ErrorMsg = giteeErr.Message
} else if giteeResp.Status == "success" {
result.FileURL = giteeResp.Output.FileURL
}
result.RawData = string(response.Bytes())
logger.Debugf("gitee 3D job response: %+v", result)
return result, nil
}
@@ -127,13 +123,13 @@ func (c *Gitee3DClient) QueryJob(taskId string) (*types.AI3DJobResult, error) {
// convertStatus 转换Gitee状态到系统状态
func (c *Gitee3DClient) convertStatus(giteeStatus string) string {
switch giteeStatus {
case "pending":
case "waiting":
return types.AI3DJobStatusPending
case "processing":
case "in_progress":
return types.AI3DJobStatusProcessing
case "completed":
return types.AI3DJobStatusCompleted
case "failed":
case "success":
return types.AI3DJobStatusSuccess
case "failure", "cancelled":
return types.AI3DJobStatusFailed
default:
return types.AI3DJobStatusPending

View File

@@ -1,13 +1,18 @@
package ai3d
import (
"encoding/json"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/go-redis/redis/v8"
@@ -22,52 +27,81 @@ type Service struct {
taskQueue *store.RedisQueue
tencentClient *Tencent3DClient
giteeClient *Gitee3DClient
userService *service.UserService
uploadManager *oss.UploaderManager
}
// NewService 创建3D生成服务
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient, userService *service.UserService, uploadManager *oss.UploaderManager) *Service {
return &Service{
db: db,
taskQueue: store.NewRedisQueue("3D_Task_Queue", redisCli),
tencentClient: tencentClient,
giteeClient: giteeClient,
userService: userService,
uploadManager: uploadManager,
}
}
// CreateJob 创建3D生成任务
func (s *Service) CreateJob(userId uint, request vo.AI3DJobCreate) (*model.AI3DJob, error) {
// 创建任务记录
job := &model.AI3DJob{
UserId: userId,
Type: request.Type,
Power: request.Power,
Model: request.Model,
Status: types.AI3DJobStatusPending,
func (s *Service) CreateJob(userId uint, request vo.AI3DJobParams) (*model.AI3DJob, error) {
switch request.Type {
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return nil, fmt.Errorf("模力方舟 3D 服务未初始化")
}
if !s.giteeClient.GetConfig().Enabled {
return nil, fmt.Errorf("模力方舟 3D 服务未启用")
}
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return nil, fmt.Errorf("腾讯云 3D 服务未初始化")
}
if !s.tencentClient.GetConfig().Enabled {
return nil, fmt.Errorf("腾讯云 3D 服务未启用")
}
default:
return nil, fmt.Errorf("不支持的 3D 服务类型: %s", request.Type)
}
// 序列化参数
params := map[string]any{
"prompt": request.Prompt,
"image_url": request.ImageURL,
"model": request.Model,
"power": request.Power,
// 创建任务记录
job := &model.AI3DJob{
UserId: userId,
Type: request.Type,
Power: request.Power,
Model: request.Model,
Status: types.AI3DJobStatusPending,
PreviewURL: request.ImageURL,
}
paramsJSON, _ := json.Marshal(params)
job.Params = string(paramsJSON)
job.Params = utils.JsonEncode(request)
// 保存到数据库
if err := s.db.Create(job).Error; err != nil {
return nil, fmt.Errorf("failed to create 3D job: %v", err)
}
// 更新用户算力
err := s.userService.DecreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: job.Model,
Remark: fmt.Sprintf("创建3D任务消耗%d算力", job.Power),
})
if err != nil {
return nil, fmt.Errorf("failed to update user power: %v", err)
}
// 将任务添加到队列
s.PushTask(job)
request.JobId = job.Id
s.PushTask(request)
return job, nil
}
// PushTask 将任务添加到队列
func (s *Service) PushTask(job *model.AI3DJob) {
func (s *Service) PushTask(job vo.AI3DJobParams) {
logger.Infof("add a new 3D task to the queue: %+v", job)
if err := s.taskQueue.RPush(job); err != nil {
logger.Errorf("push 3D task to queue failed: %v", err)
@@ -76,72 +110,70 @@ func (s *Service) PushTask(job *model.AI3DJob) {
// Run 启动任务处理器
func (s *Service) Run() {
// 将数据库中未完成的任务加载到队列
var jobs []model.AI3DJob
s.db.Where("status IN ?", []string{types.AI3DJobStatusPending, types.AI3DJobStatusProcessing}).Find(&jobs)
for _, job := range jobs {
s.PushTask(&job)
}
logger.Info("Starting 3D job consumer...")
go func() {
for {
var job model.AI3DJob
err := s.taskQueue.LPop(&job)
var params vo.AI3DJobParams
err := s.taskQueue.LPop(&params)
if err != nil {
logger.Errorf("taking 3D task with error: %v", err)
continue
}
logger.Infof("handle a new 3D task: %+v", job)
logger.Infof("handle a new 3D task: %+v", params)
go func() {
if err := s.processJob(&job); err != nil {
if err := s.processJob(&params); err != nil {
logger.Errorf("error processing 3D job: %v", err)
s.updateJobStatus(&job, types.AI3DJobStatusFailed, 0, err.Error())
s.updateJobStatus(params.JobId, types.AI3DJobStatusFailed, err.Error())
}
}()
}
}()
go s.pollJobStatus()
}
// processJob 处理3D任务
func (s *Service) processJob(job *model.AI3DJob) error {
func (s *Service) processJob(params *vo.AI3DJobParams) error {
// 更新状态为处理中
s.updateJobStatus(job, types.AI3DJobStatusProcessing, 10, "")
// 解析参数
var params map[string]any
if err := json.Unmarshal([]byte(job.Params), &params); err != nil {
return fmt.Errorf("failed to parse job params: %v", err)
}
s.updateJobStatus(params.JobId, types.AI3DJobStatusProcessing, "")
var taskId string
var err error
// 根据类型选择客户端
switch job.Type {
case "tencent":
switch params.Type {
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return fmt.Errorf("tencent 3D client not initialized")
}
tencentParams := Tencent3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
EnablePBR: false,
Prompt: params.Prompt,
ImageURL: params.ImageURL,
ResultFormat: params.FileFormat,
EnablePBR: params.EnablePBR,
}
taskId, err = s.tencentClient.SubmitJob(tencentParams)
case "gitee":
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return fmt.Errorf("gitee 3D client not initialized")
}
giteeParams := Gitee3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
Model: params.Model,
Texture: params.Texture,
Seed: params.Seed,
NumInferenceSteps: params.NumInferenceSteps,
GuidanceScale: params.GuidanceScale,
OctreeResolution: params.OctreeResolution,
ImageURL: params.ImageURL,
}
if params.Model == "Hunyuan3D-2" {
giteeParams.Type = strings.ToLower(params.FileFormat)
} else {
giteeParams.FileFormat = strings.ToLower(params.FileFormat)
}
taskId, err = s.giteeClient.SubmitJob(giteeParams)
default:
return fmt.Errorf("unsupported 3D API type: %s", job.Type)
return fmt.Errorf("unsupported 3D API type: %s", params.Type)
}
if err != nil {
@@ -149,43 +181,65 @@ func (s *Service) processJob(job *model.AI3DJob) error {
}
// 更新任务ID
job.TaskId = taskId
s.db.Model(job).Update("task_id", taskId)
// 开始轮询任务状态
go s.pollJobStatus(job)
s.db.Model(model.AI3DJob{}).Where("id = ?", params.JobId).Update("task_id", taskId)
return nil
}
// pollJobStatus 轮询任务状态
func (s *Service) pollJobStatus(job *model.AI3DJob) {
func (s *Service) pollJobStatus() {
// 10秒轮询一次
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
result, err := s.queryJobStatus(job)
for range ticker.C {
var jobs []model.AI3DJob
s.db.Where("status IN (?)", []string{types.AI3DJobStatusProcessing, types.AI3DJobStatusPending}).Find(&jobs)
if len(jobs) == 0 {
logger.Debug("no 3D jobs to poll, sleep 10s")
continue
}
for _, job := range jobs {
// 15 分钟超时
if job.CreatedAt.Before(time.Now().Add(-20 * time.Minute)) {
s.updateJobStatus(job.Id, types.AI3DJobStatusFailed, "task timeout")
continue
}
result, err := s.queryJobStatus(&job)
if err != nil {
logger.Errorf("failed to query job status: %v", err)
continue
}
// 更新进度
s.updateJobStatus(job, result.Status, result.Progress, result.ErrorMsg)
// 如果任务完成或失败,停止轮询
if result.Status == types.AI3DJobStatusCompleted || result.Status == types.AI3DJobStatusFailed {
if result.Status == types.AI3DJobStatusCompleted {
// 更新结果文件URL
s.db.Model(job).Updates(map[string]interface{}{
"img_url": result.FileURL,
"preview_url": result.PreviewURL,
})
}
return
updates := map[string]any{
"status": result.Status,
"raw_data": result.RawData,
"err_msg": result.ErrorMsg,
}
if result.FileURL != "" {
// 下载文件到本地
url, err := s.uploadManager.GetUploadHandler().PutUrlFile(result.FileURL, getFileExt(result.FileURL), false)
if err != nil {
logger.Errorf("failed to download file: %v", err)
continue
}
updates["file_url"] = url
logger.Infof("download file: %s", url)
}
if result.PreviewURL != "" {
url, err := s.uploadManager.GetUploadHandler().PutUrlFile(result.PreviewURL, getFileExt(result.PreviewURL), false)
if err != nil {
logger.Errorf("failed to download preview image: %v", err)
continue
}
updates["preview_url"] = url
logger.Infof("download preview image: %s", url)
}
s.db.Model(&model.AI3DJob{}).Where("id = ?", job.Id).Updates(updates)
}
}
}
@@ -193,12 +247,12 @@ func (s *Service) pollJobStatus(job *model.AI3DJob) {
// queryJobStatus 查询任务状态
func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, error) {
switch job.Type {
case "tencent":
case types.AI3DTaskTypeTencent:
if s.tencentClient == nil {
return nil, fmt.Errorf("tencent 3D client not initialized")
}
return s.tencentClient.QueryJob(job.TaskId)
case "gitee":
case types.AI3DTaskTypeGitee:
if s.giteeClient == nil {
return nil, fmt.Errorf("gitee 3D client not initialized")
}
@@ -209,19 +263,12 @@ func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, erro
}
// updateJobStatus 更新任务状态
func (s *Service) updateJobStatus(job *model.AI3DJob, status string, progress int, errMsg string) {
updates := map[string]interface{}{
"status": status,
"progress": progress,
"updated_at": time.Now(),
}
if errMsg != "" {
updates["err_msg"] = errMsg
}
func (s *Service) updateJobStatus(jobId uint, status string, errMsg string) error {
if err := s.db.Model(job).Updates(updates).Error; err != nil {
logger.Errorf("failed to update job status: %v", err)
}
return s.db.Model(model.AI3DJob{}).Where("id = ?", jobId).Updates(map[string]any{
"status": status,
"err_msg": errMsg,
}).Error
}
// GetJobList 获取任务列表
@@ -254,10 +301,10 @@ func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error)
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
_ = utils.JsonDecode(job.Params, &jobVO.Params)
jobList = append(jobList, jobVO)
}
@@ -269,29 +316,34 @@ func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error)
}, nil
}
// GetJobById 根据ID获取任务
func (s *Service) GetJobById(id uint) (*model.AI3DJob, error) {
var job model.AI3DJob
if err := s.db.Where("id = ?", id).First(&job).Error; err != nil {
return nil, err
}
return &job, nil
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(id uint, userId uint) error {
func (s *Service) DeleteUserJob(id uint, userId uint) error {
var job model.AI3DJob
if err := s.db.Where("id = ? AND user_id = ?", id, userId).First(&job).Error; err != nil {
err := s.db.Where("id = ?", id).Where("user_id = ?", userId).First(&job).Error
if err != nil {
return err
}
// 如果任务已完成,退还算力
if job.Status == types.AI3DJobStatusCompleted {
// TODO: 实现算力退还逻辑
logger2.GetLogger().Infof("should refund power %d for user %d", job.Power, userId)
tx := s.db.Begin()
err = tx.Delete(&job).Error
if err != nil {
return err
}
return s.db.Delete(&job).Error
// 失败的任务要退回算力
if job.Status == types.AI3DJobStatusFailed {
err = s.userService.IncreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.Model,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
tx.Rollback()
return err
}
}
tx.Commit()
return nil
}
// GetSupportedModels 获取支持的模型列表
@@ -316,12 +368,15 @@ func (s *Service) UpdateConfig(config types.AI3DConfig) {
}
}
// getString 从map中获取字符串值
func (s *Service) getString(params map[string]interface{}, key string) string {
if val, ok := params[key]; ok {
if str, ok := val.(string); ok {
return str
}
// getFileExt 获取文件扩展名
func getFileExt(fileURL string) string {
parse, err := url.Parse(fileURL)
if err != nil {
return ""
}
return ""
ext := filepath.Ext(parse.Path)
if ext == "" {
return ".glb"
}
return ext
}

View File

@@ -3,6 +3,7 @@ package ai3d
import (
"fmt"
"geekai/core/types"
"geekai/utils"
tencent3d "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d/v20250513"
tencentcloud "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
@@ -58,6 +59,10 @@ func (c *Tencent3DClient) UpdateConfig(config types.Tencent3DConfig) error {
return nil
}
func (c *Tencent3DClient) GetConfig() *types.Tencent3DConfig {
return &c.config
}
// SubmitJob 提交3D生成任务
func (c *Tencent3DClient) SubmitJob(params Tencent3DParams) (string, error) {
request := tencent3d.NewSubmitHunyuanTo3DJobRequest()
@@ -111,42 +116,39 @@ func (c *Tencent3DClient) QueryJob(jobId string) (*types.AI3DJobResult, error) {
}
result := &types.AI3DJobResult{
JobId: jobId,
Status: *response.Response.Status,
Progress: 0,
TaskId: jobId,
}
// 根据状态设置进度
switch *response.Response.Status {
case "WAIT":
result.Status = "pending"
result.Progress = 10
result.Status = types.AI3DJobStatusPending
case "RUN":
result.Status = "processing"
result.Progress = 50
result.Status = types.AI3DJobStatusProcessing
case "DONE":
result.Status = "completed"
result.Progress = 100
result.Status = types.AI3DJobStatusSuccess
// 处理结果文件
if len(response.Response.ResultFile3Ds) > 0 {
for _, file := range response.Response.ResultFile3Ds {
if file.Url != nil {
result.FileURL = *file.Url
}
if file.PreviewImageUrl != nil {
result.PreviewURL = *file.PreviewImageUrl
}
// TODO 取第一个文件
// 取第一个文件
file := response.Response.ResultFile3Ds[0]
if file.Url != nil {
result.FileURL = *file.Url
}
if file.PreviewImageUrl != nil {
result.PreviewURL = *file.PreviewImageUrl
}
}
case "FAIL":
result.Status = "failed"
result.Progress = 0
result.Status = types.AI3DJobStatusFailed
if response.Response.ErrorMessage != nil {
result.ErrorMsg = *response.Response.ErrorMessage
}
}
logger.Debugf("tencent 3D job result: %+v", *response.Response)
result.RawData = utils.JsonEncode(response.Response)
return result, nil
}

View File

@@ -32,11 +32,9 @@ func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*A
s := &AliYunOss{
proxyURL: appConfig.ProxyURL,
}
if sysConfig.OSS.Active == AliYun {
err := s.UpdateConfig(sysConfig.OSS.AliYun)
if err != nil {
logger.Errorf("阿里云OSS初始化失败: %v", err)
}
err := s.UpdateConfig(sysConfig.OSS.AliYun)
if err != nil {
logger.Warnf("阿里云OSS初始化失败: %v", err)
}
return s, nil

View File

@@ -32,11 +32,9 @@ type MiniOss struct {
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
s := &MiniOss{proxyURL: appConfig.ProxyURL}
if sysConfig.OSS.Active == Minio {
err := s.UpdateConfig(sysConfig.OSS.Minio)
if err != nil {
logger.Errorf("MinioOSS初始化失败: %v", err)
}
err := s.UpdateConfig(sysConfig.OSS.Minio)
if err != nil {
logger.Warnf("MinioOSS初始化失败: %v", err)
}
return s, nil
}

View File

@@ -37,9 +37,7 @@ func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiN
s := &QiNiuOss{
proxyURL: appConfig.ProxyURL,
}
if sysConfig.OSS.Active == QiNiu {
s.UpdateConfig(sysConfig.OSS.QiNiu)
}
s.UpdateConfig(sysConfig.OSS.QiNiu)
return s
}

View File

@@ -9,10 +9,10 @@ package oss
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
const Local = "local"
const Minio = "minio"
const QiNiu = "qiniu"
const AliYun = "aliyun"
type File struct {
Name string `json:"name"`

View File

@@ -9,7 +9,6 @@ package oss
import (
"geekai/core/types"
"strings"
logger2 "geekai/logger"
)
@@ -28,7 +27,6 @@ func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliy
if sysConfig.OSS.Active == "" {
sysConfig.OSS.Active = Local
}
sysConfig.OSS.Active = strings.ToLower(sysConfig.OSS.Active)
return &UploaderManager{
active: sysConfig.OSS.Active,

View File

@@ -1,21 +1,25 @@
package model
import "time"
import (
"geekai/core/types"
"time"
)
type AI3DJob struct {
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int(11);not null;comment:用户ID" json:"user_id"`
Type string `gorm:"column:type;type:varchar(20);not null;comment:API类型 (tencent/gitee)" json:"type"`
Power int `gorm:"column:power;type:int(11);not null;comment:消耗算力" json:"power"`
TaskId string `gorm:"column:task_id;type:varchar(100);comment:第三方任务ID" json:"task_id"`
FileURL string `gorm:"column:file_url;type:varchar(1024);comment:生成的3D模型文件地址" json:"file_url"`
PreviewURL string `gorm:"column:preview_url;type:varchar(1024);comment:预览图片地址" json:"preview_url"`
Model string `gorm:"column:model;type:varchar(50);comment:使用的3D模型类型" json:"model"`
Status string `gorm:"column:status;type:varchar(20);not null;default:pending;comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Params string `gorm:"column:params;type:text;comment:任务参数(JSON格式)" json:"params"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int(11);not null;comment:用户ID" json:"user_id"`
Type types.AI3DTaskType `gorm:"column:type;type:varchar(20);not null;comment:API类型 (tencent/gitee)" json:"type"`
Power int `gorm:"column:power;type:int(11);not null;comment:消耗算力" json:"power"`
TaskId string `gorm:"column:task_id;type:varchar(100);comment:第三方任务ID" json:"task_id"`
FileURL string `gorm:"column:file_url;type:varchar(1024);comment:生成的3D模型文件地址" json:"file_url"`
PreviewURL string `gorm:"column:preview_url;type:varchar(1024);comment:预览图片地址" json:"preview_url"`
Model string `gorm:"column:model;type:varchar(50);comment:使用的3D模型类型" json:"model"`
Status string `gorm:"column:status;type:varchar(20);not null;default:pending;comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Params string `gorm:"column:params;type:text;comment:任务参数(JSON格式)" json:"params"`
RawData string `gorm:"column:raw_data;type:text;comment:API返回的原始数据" json:"raw_data"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
}
func (m *AI3DJob) TableName() string {

View File

@@ -1,33 +1,39 @@
package vo
import "geekai/core/types"
type AI3DJob struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Type string `json:"type"`
Power int `json:"power"`
TaskId string `json:"task_id"`
FileURL string `json:"file_url"`
PreviewURL string `json:"preview_url"`
Model string `json:"model"`
Status string `json:"status"`
ErrMsg string `json:"err_msg"`
Params string `json:"params"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
Id uint `json:"id"`
UserId uint `json:"user_id"`
Type types.AI3DTaskType `json:"type"`
Power int `json:"power"`
TaskId string `json:"task_id"`
FileURL string `json:"file_url"`
PreviewURL string `json:"preview_url"`
Model string `json:"model"`
Status string `json:"status"`
ErrMsg string `json:"err_msg"`
Params AI3DJobParams `json:"params"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
type AI3DJobCreate struct {
Type string `json:"type" binding:"required"` // API类型 (tencent/gitee)
Model string `json:"model" binding:"required"` // 3D模型类型
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
Power int `json:"power" binding:"required"` // 消耗算力
}
type ThreeDJobList struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int `json:"total"`
List []AI3DJob `json:"list"`
Items []AI3DJob `json:"items"`
// AI3DJobParams 创建3D任务请求
type AI3DJobParams struct {
// 通用参数
JobId uint `json:"job_id,omitempty"` // 任务ID
Type types.AI3DTaskType `json:"type,omitempty"` // API类型 (tencent/gitee)
Model string `json:"model,omitempty"` // 3D模型类型
Prompt string `json:"prompt,omitempty"` // 文本提示词
ImageURL string `json:"image_url,omitempty"` // 输入图片URL
FileFormat string `json:"file_format,omitempty"` // 输出文件格式
Power int `json:"power,omitempty"` // 消耗算力
// 腾讯3d专有参数
EnablePBR bool `json:"enable_pbr,omitempty"` // 是否开启PBR材质
// Gitee3d专有参数
Texture bool `json:"texture,omitempty"` // 是否开启纹理
Seed int `json:"seed,omitempty"` // 随机种子
NumInferenceSteps int `json:"num_inference_steps,omitempty"` //迭代次数
GuidanceScale float64 `json:"guidance_scale,omitempty"` //引导系数
OctreeResolution int `json:"octree_resolution"` // 3D 渲染精度越高3D 细节越丰富
}