AI3D 功能完成

This commit is contained in:
GeekMaster
2025-09-04 18:36:49 +08:00
parent 53866d1461
commit 52d297624d
30 changed files with 829 additions and 969 deletions

View File

@@ -96,6 +96,9 @@ func (h *AI3DHandler) GetJobList(c *gin.Context) {
if err != nil {
continue
}
utils.JsonDecode(job.Params, &jobVo.Params)
jobVo.CreatedAt = job.CreatedAt.Unix()
jobVo.UpdatedAt = job.UpdatedAt.Unix()
jobList = append(jobList, jobVo)
}
@@ -128,6 +131,9 @@ func (h *AI3DHandler) GetJobDetail(c *gin.Context) {
resp.ERROR(c, "获取任务详情失败")
return
}
utils.JsonDecode(job.Params, &jobVo.Params)
jobVo.CreatedAt = job.CreatedAt.Unix()
jobVo.UpdatedAt = job.UpdatedAt.Unix()
resp.SUCCESS(c, jobVo)
}
@@ -167,14 +173,14 @@ func (h *AI3DHandler) GetStats(c *gin.Context) {
var stats struct {
Pending int64 `json:"pending"`
Processing int64 `json:"processing"`
Completed int64 `json:"completed"`
Success int64 `json:"success"`
Failed int64 `json:"failed"`
}
// 统计各状态的任务数量
h.db.Model(&model.AI3DJob{}).Where("status = ?", "pending").Count(&stats.Pending)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "processing").Count(&stats.Processing)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "completed").Count(&stats.Completed)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "success").Count(&stats.Success)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "failed").Count(&stats.Failed)
resp.SUCCESS(c, stats)

View File

@@ -5,7 +5,6 @@ import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/ai3d"
"geekai/store/model"
"geekai/store/vo"
@@ -19,14 +18,12 @@ import (
type AI3DHandler struct {
BaseHandler
service *ai3d.Service
userService *service.UserService
service *ai3d.Service
}
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service, userService *service.UserService) *AI3DHandler {
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service) *AI3DHandler {
return &AI3DHandler{
service: service,
userService: userService,
service: service,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -47,30 +44,14 @@ func (h *AI3DHandler) RegisterRoutes() {
group.POST("generate", h.Generate)
group.GET("jobs", h.JobList)
group.GET("jobs/mock", h.ListMock) // 演示数据接口
group.GET("job/:id", h.JobDetail)
group.GET("job/delete", h.DeleteJob)
group.GET("download/:id", h.Download)
}
}
// Generate 创建3D生成任务
func (h *AI3DHandler) Generate(c *gin.Context) {
var request struct {
// 通用参数
Type types.AI3DTaskType `json:"type" binding:"required"` // API类型 (tencent/gitee)
Model string `json:"model" binding:"required"` // 3D模型类型
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
FileFormat string `json:"file_format"` // 输出文件格式
// 腾讯3d专有参数
EnablePBR bool `json:"enable_pbr"` // 是否开启PBR材质
// Gitee3d专有参数
Texture bool `json:"texture"` // 是否开启纹理
Seed int `json:"seed"` // 随机种子
NumInferenceSteps int `json:"num_inference_steps"` //迭代次数
GuidanceScale float64 `json:"guidance_scale"` //引导系数
OctreeResolution int `json:"octree_resolution"` // 3D 渲染精度越高3D 细节越丰富
}
var request vo.AI3DJobParams
if err := c.ShouldBindJSON(&request); err != nil {
resp.ERROR(c, "参数错误")
return
@@ -90,17 +71,17 @@ func (h *AI3DHandler) Generate(c *gin.Context) {
logger.Infof("request: %+v", request)
// // 获取用户ID
// userId := h.GetLoginUserId(c)
// // 创建任务
// job, err := h.service.CreateJob(uint(userId), request)
// if err != nil {
// resp.ERROR(c, fmt.Sprintf("创建任务失败: %v", err))
// return
// }
// 获取用户ID
userId := h.GetLoginUserId(c)
// 创建任务
job, err := h.service.CreateJob(uint(userId), request)
if err != nil {
resp.ERROR(c, fmt.Sprintf("创建任务失败: %v", err))
return
}
resp.SUCCESS(c, gin.H{
"job_id": 0,
"job_id": job.Id,
"message": "任务创建成功",
})
}
@@ -132,133 +113,24 @@ func (h *AI3DHandler) JobList(c *gin.Context) {
resp.SUCCESS(c, jobList)
}
// JobDetail 获取任务详情
func (h *AI3DHandler) JobDetail(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 转换为VO
jobVO := vo.AI3DJob{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Power: job.Power,
TaskId: job.TaskId,
FileURL: job.FileURL,
PreviewURL: job.PreviewURL,
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
resp.SUCCESS(c, jobVO)
}
// DeleteJob 删除任务
func (h *AI3DHandler) DeleteJob(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := c.Query("id")
if id == "" {
id := h.GetInt(c, "id", 0)
if id == 0 {
resp.ERROR(c, "任务ID不能为空")
return
}
var job model.AI3DJob
err := h.DB.Where("id = ?", id).Where("user_id = ?", userId).First(&job).Error
err := h.service.DeleteUserJob(uint(id), uint(userId))
if err != nil {
resp.ERROR(c, err.Error())
resp.ERROR(c, "删除任务失败")
return
}
err = h.DB.Delete(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 失败的任务要退回算力
if job.Status == types.AI3DJobStatusFailed {
err = h.userService.IncreasePower(userId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.Model,
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, gin.H{"message": "删除成功"})
}
// Download 下载3D模型
func (h *AI3DHandler) Download(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 检查任务状态
if job.Status != types.AI3DJobStatusCompleted {
resp.ERROR(c, "任务尚未完成")
return
}
if job.FileURL == "" {
resp.ERROR(c, "模型文件不存在")
return
}
// 重定向到下载链接
c.Redirect(302, job.FileURL)
}
// GetConfigs 获取3D生成配置
func (h *AI3DHandler) GetConfigs(c *gin.Context) {
var config model.Config
@@ -281,8 +153,6 @@ func (h *AI3DHandler) GetConfigs(c *gin.Context) {
config3d.Tencent.Models = models["tencent"]
}
logger.Info("config3d: ", config3d)
resp.SUCCESS(c, config3d)
}
@@ -299,9 +169,9 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
FileURL: "https://img.r9it.com/R03TQZ7PZ386RGL7PTMNGFOHAJW15WYF.glb",
PreviewURL: "/static/upload/2025/9/1756873317505073.png",
Model: "gitee-3d-v1",
Status: types.AI3DJobStatusCompleted,
Status: types.AI3DJobStatusSuccess,
ErrMsg: "",
Params: `{"prompt":"一只可爱的小猫","image_url":"","texture":true,"seed":42}`,
Params: vo.AI3DJobParams{Prompt: "一只可爱的小猫", ImageURL: "", Texture: true, Seed: 42},
CreatedAt: 1704067200, // 2024-01-01 00:00:00
UpdatedAt: 1704067800, // 2024-01-01 00:10:00
},
@@ -316,7 +186,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v2",
Status: types.AI3DJobStatusProcessing,
ErrMsg: "",
Params: `{"prompt":"一个现代建筑模型","image_url":"","enable_pbr":true}`,
Params: vo.AI3DJobParams{Prompt: "一个现代建筑模型", ImageURL: "", EnablePBR: true},
CreatedAt: 1704070800, // 2024-01-01 01:00:00
UpdatedAt: 1704070800, // 2024-01-01 01:00:00
},
@@ -331,7 +201,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "gitee-3d-v1",
Status: types.AI3DJobStatusPending,
ErrMsg: "",
Params: `{"prompt":"一辆跑车模型","image_url":"https://example.com/car.jpg","texture":false}`,
Params: vo.AI3DJobParams{Prompt: "一辆跑车模型", ImageURL: "https://example.com/car.jpg", Texture: false},
CreatedAt: 1704074400, // 2024-01-01 02:00:00
UpdatedAt: 1704074400, // 2024-01-01 02:00:00
},
@@ -346,7 +216,7 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v1",
Status: types.AI3DJobStatusFailed,
ErrMsg: "模型生成失败:输入图片质量不符合要求",
Params: `{"prompt":"一个机器人模型","image_url":"https://example.com/robot.jpg","enable_pbr":false}`,
Params: vo.AI3DJobParams{Prompt: "一个机器人模型", ImageURL: "https://example.com/robot.jpg", EnablePBR: false},
CreatedAt: 1704078000, // 2024-01-01 03:00:00
UpdatedAt: 1704078600, // 2024-01-01 03:10:00
},
@@ -359,9 +229,9 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
FileURL: "https://ai.gitee.com/a8c1af8e-26e9-4ca6-aa5c-6d4ba86bfdac",
PreviewURL: "https://ai.gitee.com/a8c1af8e-26e9-4ca6-aa5c-6d4ba86bfdac",
Model: "gitee-3d-v2",
Status: types.AI3DJobStatusCompleted,
Status: types.AI3DJobStatusSuccess,
ErrMsg: "",
Params: `{"prompt":"一个复杂的机械装置","image_url":"","texture":true,"octree_resolution":512}`,
Params: vo.AI3DJobParams{Prompt: "一个复杂的机械装置", ImageURL: "", Texture: true, OctreeResolution: 512},
CreatedAt: 1704081600, // 2024-01-01 04:00:00
UpdatedAt: 1704082200, // 2024-01-01 04:10:00
},
@@ -376,17 +246,17 @@ func (h *AI3DHandler) ListMock(c *gin.Context) {
Model: "tencent-3d-v2",
Status: types.AI3DJobStatusProcessing,
ErrMsg: "",
Params: `{"prompt":"一个科幻飞船","image_url":"","enable_pbr":true}`,
Params: vo.AI3DJobParams{Prompt: "一个科幻飞船", ImageURL: "", EnablePBR: true},
CreatedAt: 1704085200, // 2024-01-01 05:00:00
UpdatedAt: 1704085200, // 2024-01-01 05:00:00
},
}
// 创建分页响应
mockResponse := vo.ThreeDJobList{
mockResponse := vo.Page{
Page: 1,
PageSize: 10,
Total: len(mockJobs),
Total: int64(len(mockJobs)),
Items: mockJobs,
}