3D生成服务已经完成

This commit is contained in:
GeekMaster
2025-09-02 18:55:45 +08:00
parent 85b4cc0a3c
commit f8e4d2880f
40 changed files with 4920 additions and 395 deletions

View File

@@ -157,6 +157,24 @@ func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
logger.Error("load moderation config error: ", err)
}
// 加载即梦AI配置
var jimengConfig types.JimengConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
if err != nil {
logger.Error("load jimeng config error: ", err)
}
// 加载3D生成配置
var ai3dConfig types.AI3DConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyAI3D).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &ai3dConfig)
if err != nil {
logger.Error("load ai3d config error: ", err)
}
return &types.SystemConfig{
Base: baseConfig,
License: license,
@@ -167,5 +185,7 @@ func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
Captcha: captchaConfig,
WxLogin: wxLoginConfig,
Moderation: moderationConfig,
Jimeng: jimengConfig,
AI3D: ai3dConfig,
}
}

58
api/core/types/ai3d.go Normal file
View File

@@ -0,0 +1,58 @@
package types
// AI3DConfig 3D生成配置
type AI3DConfig struct {
Tencent Tencent3DConfig `json:"tencent,omitempty"`
Gitee Gitee3DConfig `json:"gitee,omitempty"`
}
// Tencent3DConfig 腾讯云3D配置
type Tencent3DConfig struct {
SecretId string `json:"secret_id,omitempty"`
SecretKey string `json:"secret_key,omitempty"`
Region string `json:"region,omitempty"`
Enabled bool `json:"enabled,omitempty"`
Models []AI3DModel `json:"models,omitempty"`
}
// Gitee3DConfig Gitee 3D配置
type Gitee3DConfig struct {
APIKey string `json:"api_key,omitempty"`
Enabled bool `json:"enabled,omitempty"`
Models []AI3DModel `json:"models,omitempty"`
}
// AI3DJobResult 3D任务结果
type AI3DJobResult struct {
JobId string `json:"job_id"` // 任务ID
Status string `json:"status"` // 任务状态
Progress int `json:"progress"` // 任务进度 (0-100)
FileURL string `json:"file_url"` // 3D模型文件URL
PreviewURL string `json:"preview_url"` // 预览图片URL
ErrorMsg string `json:"error_msg"` // 错误信息
}
// AI3DModel 3D模型配置
type AI3DModel struct {
Name string `json:"name"` // 模型名称
Desc string `json:"desc"` // 模型描述
Power int `json:"power"` // 算力消耗
Formats []string `json:"formats"` // 支持输出的文件格式
}
// AI3DJobRequest 3D任务请求
type AI3DJobRequest struct {
Type string `json:"type"` // API类型 (tencent/gitee)
Model string `json:"model"` // 3D模型类型
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
Power int `json:"power"` // 消耗算力
}
// AI3DJobStatus 3D任务状态
const (
AI3DJobStatusPending = "pending" // 等待中
AI3DJobStatusProcessing = "processing" // 处理中
AI3DJobStatusCompleted = "completed" // 已完成
AI3DJobStatusFailed = "failed" // 失败
)

View File

@@ -108,6 +108,7 @@ type SystemConfig struct {
Captcha CaptchaConfig
WxLogin WxLoginConfig
Jimeng JimengConfig
AI3D AI3DConfig
License License
Moderation ModerationConfig
}
@@ -127,4 +128,6 @@ const (
ConfigKeyOss = "oss"
ConfigKeyPayment = "payment"
ConfigKeyModeration = "moderation"
ConfigKeyAI3D = "ai3d"
ConfigKeyJimeng = "jimeng"
)

View File

@@ -33,6 +33,8 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.3.1
github.com/syndtr/goleveldb v1.0.0
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d v1.1.0
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.21
golang.org/x/image v0.15.0
)
@@ -46,8 +48,6 @@ require (
github.com/go-pay/xtime v0.0.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d v1.1.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.20 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect

View File

@@ -248,8 +248,8 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d v1.1.0 h1:hOyYsl35o74hOhnnPVQIK/bdSIPNp3TKJlCEOXGO7ms=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d v1.1.0/go.mod h1:3689peGF1zp+P9c+GnUcAzkMp+kXi0Tr44zeQ57Z+7Y=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.0/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.20 h1:8B80/p+WvzBVz+jM6dosTcfhRe7Jotpyqj0NoGW1wfE=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.20/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.21 h1:ikHhyiq1PiPytUMtEblKPkbf0zzTEi3CpE9z0MARlqY=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.1.21/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=

View File

@@ -0,0 +1,216 @@
package admin
import (
"strconv"
"geekai/core"
"geekai/core/types"
"geekai/service/ai3d"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// AI3DHandler 3D管理处理器
type AI3DHandler struct {
app *core.AppServer
db *gorm.DB
service *ai3d.Service
}
// NewAI3DHandler 创建3D管理处理器
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service) *AI3DHandler {
return &AI3DHandler{
app: app,
db: db,
service: service,
}
}
// RegisterRoutes 注册路由
func (h *AI3DHandler) RegisterRoutes() {
admin := h.app.Engine.Group("/api/admin/ai3d")
{
admin.GET("/jobs", h.GetJobList)
admin.GET("/jobs/:id", h.GetJobDetail)
admin.DELETE("/jobs/:id", h.DeleteJob)
admin.GET("/stats", h.GetStats)
admin.GET("/models", h.GetModels)
admin.POST("/config", h.SaveConfig)
}
}
// GetJobList 获取任务列表
func (h *AI3DHandler) GetJobList(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
status := c.Query("status")
jobType := c.Query("type")
userIdStr := c.Query("user_id")
var userId uint
if userIdStr != "" {
if id, err := strconv.ParseUint(userIdStr, 10, 32); err == nil {
userId = uint(id)
}
}
// 构建查询条件
query := h.db.Model(&model.AI3DJob{})
if status != "" {
query = query.Where("status = ?", status)
}
if jobType != "" {
query = query.Where("type = ?", jobType)
}
if userId > 0 {
query = query.Where("user_id = ?", userId)
}
// 获取总数
var total int64
query.Count(&total)
// 获取分页数据
var jobs []model.AI3DJob
offset := (page - 1) * pageSize
err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error
if err != nil {
resp.ERROR(c, "获取任务列表失败")
return
}
// 转换为VO
var jobList []vo.AI3DJob
for _, job := range jobs {
var jobVo vo.AI3DJob
err = utils.CopyObject(job, &jobVo)
if err != nil {
continue
}
jobList = append(jobList, jobVo)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, jobList))
}
// GetJobDetail 获取任务详情
func (h *AI3DHandler) GetJobDetail(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "无效的任务ID")
return
}
var job model.AI3DJob
err = h.db.First(&job, uint(id)).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
resp.ERROR(c, "任务不存在")
} else {
resp.ERROR(c, "获取任务详情失败")
}
return
}
var jobVo vo.AI3DJob
err = utils.CopyObject(job, &jobVo)
if err != nil {
resp.ERROR(c, "获取任务详情失败")
return
}
resp.SUCCESS(c, jobVo)
}
// DeleteJob 删除任务
func (h *AI3DHandler) DeleteJob(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "无效的任务ID")
return
}
// 检查任务是否存在
var job model.AI3DJob
err = h.db.First(&job, uint(id)).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
resp.ERROR(c, "任务不存在")
} else {
resp.ERROR(c, "获取任务失败")
}
return
}
// 删除任务
err = h.db.Delete(&job).Error
if err != nil {
resp.ERROR(c, "删除任务失败")
return
}
resp.SUCCESS(c, "删除成功")
}
// GetStats 获取统计数据
func (h *AI3DHandler) GetStats(c *gin.Context) {
var stats struct {
Pending int64 `json:"pending"`
Processing int64 `json:"processing"`
Completed int64 `json:"completed"`
Failed int64 `json:"failed"`
}
// 统计各状态的任务数量
h.db.Model(&model.AI3DJob{}).Where("status = ?", "pending").Count(&stats.Pending)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "processing").Count(&stats.Processing)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "completed").Count(&stats.Completed)
h.db.Model(&model.AI3DJob{}).Where("status = ?", "failed").Count(&stats.Failed)
resp.SUCCESS(c, stats)
}
// GetModels 获取配置
func (h *AI3DHandler) GetModels(c *gin.Context) {
models := h.service.GetSupportedModels()
resp.SUCCESS(c, models)
}
// SaveGlobalSettings 保存全局配置
func (h *AI3DHandler) SaveConfig(c *gin.Context) {
var config types.AI3DConfig
err := c.ShouldBindJSON(&config)
if err != nil {
resp.ERROR(c, "参数错误")
return
}
var exist model.Config
err = h.db.Where("name", types.ConfigKeyAI3D).First(&exist).Error
if err != nil {
exist.Name = types.ConfigKeyAI3D
exist.Value = utils.JsonEncode(config)
err = h.db.Create(&exist).Error
} else {
exist.Value = utils.JsonEncode(config)
err = h.db.Updates(&exist).Error
}
if err != nil {
resp.ERROR(c, "保存配置失败")
return
}
h.service.UpdateConfig(config)
h.app.SysConfig.AI3D = config
resp.SUCCESS(c, "保存成功")
}

View File

@@ -8,13 +8,11 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/service/payment"
"geekai/service/sms"
@@ -28,17 +26,16 @@ import (
type ConfigHandler struct {
handler.BaseHandler
licenseService *service.LicenseService
sysConfig *types.SystemConfig
alipayService *payment.AlipayService
wxpayService *payment.WxPayService
epayService *payment.EPayService
smsManager *sms.SmsManager
uploaderManager *oss.UploaderManager
smtpService *service.SmtpService
captchaService *service.CaptchaService
wxLoginService *service.WxLoginService
moderationManager *moderation.ServiceManager
licenseService *service.LicenseService
sysConfig *types.SystemConfig
alipayService *payment.AlipayService
wxpayService *payment.WxPayService
epayService *payment.EPayService
smsManager *sms.SmsManager
uploaderManager *oss.UploaderManager
smtpService *service.SmtpService
captchaService *service.CaptchaService
wxLoginService *service.WxLoginService
}
func NewConfigHandler(
@@ -54,21 +51,19 @@ func NewConfigHandler(
smtpService *service.SmtpService,
captchaService *service.CaptchaService,
wxLoginService *service.WxLoginService,
moderationManager *moderation.ServiceManager,
) *ConfigHandler {
return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
licenseService: licenseService,
sysConfig: sysConfig,
alipayService: alipayService,
wxpayService: wxpayService,
epayService: epayService,
smsManager: smsManager,
uploaderManager: uploaderManager,
moderationManager: moderationManager,
smtpService: smtpService,
captchaService: captchaService,
wxLoginService: wxLoginService,
BaseHandler: handler.BaseHandler{App: app, DB: db},
licenseService: licenseService,
sysConfig: sysConfig,
alipayService: alipayService,
wxpayService: wxpayService,
epayService: epayService,
smsManager: smsManager,
uploaderManager: uploaderManager,
smtpService: smtpService,
captchaService: captchaService,
wxLoginService: wxLoginService,
}
}
@@ -91,8 +86,6 @@ func (h *ConfigHandler) RegisterRoutes() {
rg.POST("update/sms", h.UpdateSms)
rg.POST("update/oss", h.UpdateOss)
rg.POST("update/smtp", h.UpdateStmp)
rg.POST("update/moderation", h.UpdateModeration)
rg.POST("moderation/test", h.TestModeration)
rg.GET("get", h.Get)
rg.POST("license/active", h.Active)
rg.GET("license/get", h.GetLicense)
@@ -450,90 +443,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}
// UpdateModeration 更新文本审查配置
func (h *ConfigHandler) UpdateModeration(c *gin.Context) {
var data types.ModerationConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyModeration, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.moderationManager.UpdateConfig(data)
h.sysConfig.Moderation = data
resp.SUCCESS(c, data)
}
// 测试结果类型,用于前端显示
type ModerationTestResult struct {
IsAbnormal bool `json:"isAbnormal"`
Details []ModerationTestDetail `json:"details"`
}
type ModerationTestDetail struct {
Category string `json:"category"`
Description string `json:"description"`
Confidence string `json:"confidence"`
IsCategory bool `json:"isCategory"`
}
// TestModeration 测试文本审查服务
func (h *ConfigHandler) TestModeration(c *gin.Context) {
var data struct {
Text string `json:"text"`
Service string `json:"service"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Text == "" {
resp.ERROR(c, "测试文本不能为空")
return
}
// 检查是否启用了文本审查
if !h.sysConfig.Moderation.Enable {
resp.ERROR(c, "文本审查服务未启用")
return
}
// 获取当前激活的审核服务
service := h.moderationManager.GetService()
// 执行文本审核
result, err := service.Moderate(data.Text)
if err != nil {
resp.ERROR(c, "审核服务调用失败: "+err.Error())
return
}
// 转换为前端需要的格式
testResult := ModerationTestResult{
IsAbnormal: result.Flagged,
Details: make([]ModerationTestDetail, 0),
}
// 构建详细信息
for category, description := range types.ModerationCategories {
score := result.CategoryScores[category]
isCategory := result.Categories[category]
testResult.Details = append(testResult.Details, ModerationTestDetail{
Category: category,
Description: description,
Confidence: fmt.Sprintf("%.2f", score),
IsCategory: isCategory,
})
}
resp.SUCCESS(c, testResult)
}

View File

@@ -21,18 +21,18 @@ import (
// AdminJimengHandler 管理后台即梦AI处理器
type AdminJimengHandler struct {
handler.BaseHandler
jimengService *jimeng.Service
userService *service.UserService
uploader *oss.UploaderManager
jimengClient *jimeng.Client
userService *service.UserService
uploader *oss.UploaderManager
}
// NewAdminJimengHandler 创建管理后台即梦AI处理器
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
return &AdminJimengHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
jimengService: jimengService,
userService: userService,
uploader: uploader,
BaseHandler: handler.BaseHandler{App: app, DB: db},
jimengClient: jimengClient,
userService: userService,
uploader: uploader,
}
}
@@ -43,7 +43,6 @@ func (h *AdminJimengHandler) RegisterRoutes() {
rg.GET("/jobs/:id", h.JobDetail)
rg.POST("/jobs/remove", h.BatchRemove)
rg.GET("/stats", h.Stats)
rg.GET("/config", h.GetConfig)
rg.POST("/config/update", h.UpdateConfig)
}
@@ -213,12 +212,6 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
resp.SUCCESS(c, result)
}
// GetConfig 获取即梦AI配置
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
jimengConfig := h.jimengService.GetConfig()
resp.SUCCESS(c, jimengConfig)
}
// UpdateConfig 更新即梦AI配置
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
var req types.JimengConfig
@@ -266,9 +259,9 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
// 保存配置
tx := h.DB.Begin()
value := utils.JsonEncode(&req)
config := model.Config{Name: "jimeng", Value: value}
config := model.Config{Name: types.ConfigKeyJimeng, Value: value}
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
err := tx.FirstOrCreate(&config).Error
if err != nil {
resp.ERROR(c, "保存配置失败: "+err.Error())
return
@@ -284,13 +277,14 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
}
// 更新服务中的客户端配置
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
if updateErr != nil {
resp.ERROR(c, updateErr.Error())
err = h.jimengClient.UpdateConfig(req)
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
tx.Commit()
h.App.SysConfig.Jimeng = req
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
}

View File

@@ -8,10 +8,12 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service/moderation"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
@@ -22,10 +24,12 @@ import (
type ModerationHandler struct {
handler.BaseHandler
sysConfig *types.SystemConfig
moderationManager *moderation.ServiceManager
}
func NewModerationHandler(app *core.AppServer, db *gorm.DB) *ModerationHandler {
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
}
// RegisterRoutes 注册路由
@@ -39,6 +43,8 @@ func (h *ModerationHandler) RegisterRoutes() {
group.GET("remove", h.Remove)
group.POST("batch-remove", h.BatchRemove)
group.GET("source-list", h.GetSourceList)
group.POST("config", h.UpdateModeration)
group.POST("test", h.TestModeration)
}
}
@@ -229,3 +235,90 @@ func (h *ModerationHandler) GetSourceList(c *gin.Context) {
resp.SUCCESS(c, sources)
}
// UpdateModeration 更新文本审查配置
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
var data types.ModerationConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("name", types.ConfigKeyModeration).FirstOrCreate(&model.Config{Name: types.ConfigKeyModeration, Value: utils.JsonEncode(data)}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.moderationManager.UpdateConfig(data)
h.sysConfig.Moderation = data
resp.SUCCESS(c, data)
}
// 测试结果类型,用于前端显示
type ModerationTestResult struct {
IsAbnormal bool `json:"isAbnormal"`
Details []ModerationTestDetail `json:"details"`
}
type ModerationTestDetail struct {
Category string `json:"category"`
Description string `json:"description"`
Confidence string `json:"confidence"`
IsCategory bool `json:"isCategory"`
}
// TestModeration 测试文本审查服务
func (h *ModerationHandler) TestModeration(c *gin.Context) {
var data struct {
Text string `json:"text"`
Service string `json:"service"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Text == "" {
resp.ERROR(c, "测试文本不能为空")
return
}
// 检查是否启用了文本审查
if !h.sysConfig.Moderation.Enable {
resp.ERROR(c, "文本审查服务未启用")
return
}
// 获取当前激活的审核服务
service := h.moderationManager.GetService()
// 执行文本审核
result, err := service.Moderate(data.Text)
if err != nil {
resp.ERROR(c, "审核服务调用失败: "+err.Error())
return
}
// 转换为前端需要的格式
testResult := ModerationTestResult{
IsAbnormal: result.Flagged,
Details: make([]ModerationTestDetail, 0),
}
// 构建详细信息
for category, description := range types.ModerationCategories {
score := result.CategoryScores[category]
isCategory := result.Categories[category]
testResult.Details = append(testResult.Details, ModerationTestDetail{
Category: category,
Description: description,
Confidence: fmt.Sprintf("%.2f", score),
IsCategory: isCategory,
})
}
resp.SUCCESS(c, testResult)
}

236
api/handler/ai3d_handler.go Normal file
View File

@@ -0,0 +1,236 @@
package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/ai3d"
"geekai/store/vo"
"geekai/utils/resp"
"strconv"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type AI3DHandler struct {
BaseHandler
service *ai3d.Service
userService *service.UserService
}
func NewAI3DHandler(app *core.AppServer, db *gorm.DB, service *ai3d.Service, userService *service.UserService) *AI3DHandler {
return &AI3DHandler{
service: service,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
}
// RegisterRoutes 注册路由
func (h *AI3DHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/3d/")
// 公开接口,不需要授权
group.GET("models/:type", h.GetModels)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("generate", h.Generate)
group.GET("jobs", h.JobList)
group.GET("job/:id", h.JobDetail)
group.DELETE("job/:id", h.DeleteJob)
group.GET("download/:id", h.Download)
}
}
// Generate 创建3D生成任务
func (h *AI3DHandler) Generate(c *gin.Context) {
var request vo.AI3DJobCreate
if err := c.ShouldBindJSON(&request); err != nil {
resp.ERROR(c, "参数错误")
return
}
// 验证必填参数
if request.Type == "" || request.Model == "" || request.Power <= 0 {
resp.ERROR(c, "缺少必要参数")
return
}
// 获取用户ID
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
// 创建任务
job, err := h.service.CreateJob(uint(userId), request)
if err != nil {
resp.ERROR(c, fmt.Sprintf("创建任务失败: %v", err))
return
}
resp.SUCCESS(c, gin.H{
"job_id": job.Id,
"message": "任务创建成功",
})
}
// JobList 获取任务列表
func (h *AI3DHandler) JobList(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
jobList, err := h.service.GetJobList(uint(userId), page, pageSize)
if err != nil {
resp.ERROR(c, fmt.Sprintf("获取任务列表失败: %v", err))
return
}
resp.SUCCESS(c, jobList)
}
// JobDetail 获取任务详情
func (h *AI3DHandler) JobDetail(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 转换为VO
jobVO := vo.AI3DJob{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Power: job.Power,
TaskId: job.TaskId,
ImgURL: job.FileURL,
PreviewURL: job.PreviewURL,
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
resp.SUCCESS(c, jobVO)
}
// DeleteJob 删除任务
func (h *AI3DHandler) DeleteJob(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
err = h.service.DeleteJob(uint(id), uint(userId))
if err != nil {
resp.ERROR(c, fmt.Sprintf("删除任务失败: %v", err))
return
}
resp.SUCCESS(c, gin.H{"message": "删除成功"})
}
// Download 下载3D模型
func (h *AI3DHandler) Download(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.ERROR(c, "用户未登录")
return
}
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "任务ID格式错误")
return
}
job, err := h.service.GetJobById(uint(id))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
// 检查权限
if job.UserId != uint(userId) {
resp.ERROR(c, "无权限访问此任务")
return
}
// 检查任务状态
if job.Status != types.AI3DJobStatusCompleted {
resp.ERROR(c, "任务尚未完成")
return
}
if job.FileURL == "" {
resp.ERROR(c, "模型文件不存在")
return
}
// 重定向到下载链接
c.Redirect(302, job.FileURL)
}
// GetModels 获取支持的模型列表
func (h *AI3DHandler) GetModels(c *gin.Context) {
models := h.service.GetSupportedModels()
if len(models) == 0 {
resp.ERROR(c, "无可用3D模型")
return
}
resp.SUCCESS(c, models)
}

View File

@@ -435,7 +435,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
// getPowerFromConfig 从配置中获取指定类型的算力消耗
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
config := h.jimengService.GetConfig()
config := h.App.SysConfig.Jimeng
switch taskType {
case model.JMTaskTypeTextToImage:
@@ -457,7 +457,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
// GetPowerConfig 获取即梦各任务类型算力消耗配置
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
config := h.jimengService.GetConfig()
config := h.App.SysConfig.Jimeng
resp.SUCCESS(c, gin.H{
"text_to_image": config.Power.TextToImage,
"image_to_image": config.Power.ImageToImage,

View File

@@ -16,6 +16,7 @@ import (
"geekai/handler/admin"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/ai3d"
"geekai/service/dalle"
"geekai/service/jimeng"
"geekai/service/mj"
@@ -210,10 +211,19 @@ func main() {
}),
// 即梦AI 服务
fx.Provide(jimeng.NewClient),
fx.Provide(jimeng.NewService),
fx.Invoke(func(service *jimeng.Service) {
service.Start()
}),
// 3D生成服务
fx.Provide(ai3d.NewTencent3DClient),
fx.Provide(ai3d.NewGitee3DClient),
fx.Provide(ai3d.NewService),
fx.Invoke(func(s *ai3d.Service) {
s.Run()
}),
fx.Provide(service.NewSnowflake),
// 创建短信服务
@@ -383,6 +393,16 @@ func main() {
h.RegisterRoutes()
}),
// 3D生成处理器
fx.Provide(handler.NewAI3DHandler),
fx.Invoke(func(s *core.AppServer, h *handler.AI3DHandler) {
h.RegisterRoutes()
}),
fx.Provide(admin.NewAI3DHandler),
fx.Invoke(func(s *core.AppServer, h *admin.AI3DHandler) {
h.RegisterRoutes()
}),
// 即梦AI 路由
fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) {
h.RegisterRoutes()

View File

@@ -0,0 +1,150 @@
package ai3d
import (
"encoding/json"
"fmt"
"geekai/core/types"
"time"
"github.com/imroc/req/v3"
)
type Gitee3DClient struct {
httpClient *req.Client
config types.Gitee3DConfig
apiURL string
}
type Gitee3DParams struct {
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
ResultFormat string `json:"result_format"` // 输出格式
}
type Gitee3DResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
type Gitee3DQueryResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Status string `json:"status"`
Progress int `json:"progress"`
ResultURL string `json:"result_url"`
PreviewURL string `json:"preview_url"`
ErrorMsg string `json:"error_msg"`
} `json:"data"`
}
func NewGitee3DClient(sysConfig *types.SystemConfig) *Gitee3DClient {
return &Gitee3DClient{
httpClient: req.C().SetTimeout(time.Minute * 3),
config: sysConfig.AI3D.Gitee,
apiURL: "https://ai.gitee.com/v1/async/image-to-3d",
}
}
func (c *Gitee3DClient) UpdateConfig(config types.Gitee3DConfig) {
c.config = config
}
// SubmitJob 提交3D生成任务
func (c *Gitee3DClient) SubmitJob(params Gitee3DParams) (string, error) {
requestBody := map[string]any{
"prompt": params.Prompt,
"image_url": params.ImageURL,
"result_format": params.ResultFormat,
}
response, err := c.httpClient.R().
SetHeader("Authorization", "Bearer "+c.config.APIKey).
SetHeader("Content-Type", "application/json").
SetBody(requestBody).
Post(c.apiURL + "/async/image-to-3d")
if err != nil {
return "", fmt.Errorf("failed to submit gitee 3D job: %v", err)
}
var giteeResp Gitee3DResponse
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
return "", fmt.Errorf("failed to parse gitee response: %v", err)
}
if giteeResp.Code != 0 {
return "", fmt.Errorf("gitee API error: %s", giteeResp.Message)
}
if giteeResp.Data.TaskID == "" {
return "", fmt.Errorf("no task ID returned from gitee 3D API")
}
return giteeResp.Data.TaskID, nil
}
// QueryJob 查询任务状态
func (c *Gitee3DClient) QueryJob(taskId string) (*types.AI3DJobResult, error) {
response, err := c.httpClient.R().
SetHeader("Authorization", "Bearer "+c.config.APIKey).
Get(fmt.Sprintf("%s/task/%s/get", c.apiURL, taskId))
if err != nil {
return nil, fmt.Errorf("failed to query gitee 3D job: %v", err)
}
var giteeResp Gitee3DQueryResponse
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
return nil, fmt.Errorf("failed to parse gitee query response: %v", err)
}
if giteeResp.Code != 0 {
return nil, fmt.Errorf("gitee API error: %s", giteeResp.Message)
}
result := &types.AI3DJobResult{
JobId: taskId,
Status: c.convertStatus(giteeResp.Data.Status),
Progress: giteeResp.Data.Progress,
}
// 根据状态设置结果
switch giteeResp.Data.Status {
case "completed":
result.FileURL = giteeResp.Data.ResultURL
result.PreviewURL = giteeResp.Data.PreviewURL
case "failed":
result.ErrorMsg = giteeResp.Data.ErrorMsg
}
return result, nil
}
// convertStatus 转换Gitee状态到系统状态
func (c *Gitee3DClient) convertStatus(giteeStatus string) string {
switch giteeStatus {
case "pending":
return types.AI3DJobStatusPending
case "processing":
return types.AI3DJobStatusProcessing
case "completed":
return types.AI3DJobStatusCompleted
case "failed":
return types.AI3DJobStatusFailed
default:
return types.AI3DJobStatusPending
}
}
// GetSupportedModels 获取支持的模型列表
func (c *Gitee3DClient) GetSupportedModels() []types.AI3DModel {
return []types.AI3DModel{
{Name: "Hunyuan3D-2", Power: 100, Formats: []string{"GLB"}, Desc: "Hunyuan3D-2 是腾讯混元团队推出的高质量 3D 生成模型,具备高保真度、细节丰富和高效生成的特点,可快速将文本或图像转换为逼真的 3D 物体。"},
{Name: "Step1X-3D", Power: 55, Formats: []string{"GLB", "STL"}, Desc: "Step1X-3D 是一款由阶跃星辰StepFun与光影焕像LightIllusions联合研发并开源的高保真 3D 生成模型,专为高质量、可控的 3D 内容创作而设计。"},
{Name: "Hi3DGen", Power: 35, Formats: []string{"GLB", "STL"}, Desc: "Hi3DGen 是一个 AI 工具,它可以把你上传的普通图片,智能转换成有“立体感”的图片(法线图),常用于制作 3D 效果,比如游戏建模、虚拟现实、动画制作等。"},
}
}

327
api/service/ai3d/service.go Normal file
View File

@@ -0,0 +1,327 @@
package ai3d
import (
"encoding/json"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store"
"geekai/store/model"
"geekai/store/vo"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// Service 3D生成服务
type Service struct {
db *gorm.DB
taskQueue *store.RedisQueue
tencentClient *Tencent3DClient
giteeClient *Gitee3DClient
}
// NewService 创建3D生成服务
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient) *Service {
return &Service{
db: db,
taskQueue: store.NewRedisQueue("3D_Task_Queue", redisCli),
tencentClient: tencentClient,
giteeClient: giteeClient,
}
}
// CreateJob 创建3D生成任务
func (s *Service) CreateJob(userId uint, request vo.AI3DJobCreate) (*model.AI3DJob, error) {
// 创建任务记录
job := &model.AI3DJob{
UserId: userId,
Type: request.Type,
Power: request.Power,
Model: request.Model,
Status: types.AI3DJobStatusPending,
}
// 序列化参数
params := map[string]any{
"prompt": request.Prompt,
"image_url": request.ImageURL,
"model": request.Model,
"power": request.Power,
}
paramsJSON, _ := json.Marshal(params)
job.Params = string(paramsJSON)
// 保存到数据库
if err := s.db.Create(job).Error; err != nil {
return nil, fmt.Errorf("failed to create 3D job: %v", err)
}
// 将任务添加到队列
s.PushTask(job)
return job, nil
}
// PushTask 将任务添加到队列
func (s *Service) PushTask(job *model.AI3DJob) {
logger.Infof("add a new 3D task to the queue: %+v", job)
if err := s.taskQueue.RPush(job); err != nil {
logger.Errorf("push 3D task to queue failed: %v", err)
}
}
// Run 启动任务处理器
func (s *Service) Run() {
// 将数据库中未完成的任务加载到队列
var jobs []model.AI3DJob
s.db.Where("status IN ?", []string{types.AI3DJobStatusPending, types.AI3DJobStatusProcessing}).Find(&jobs)
for _, job := range jobs {
s.PushTask(&job)
}
logger.Info("Starting 3D job consumer...")
go func() {
for {
var job model.AI3DJob
err := s.taskQueue.LPop(&job)
if err != nil {
logger.Errorf("taking 3D task with error: %v", err)
continue
}
logger.Infof("handle a new 3D task: %+v", job)
go func() {
if err := s.processJob(&job); err != nil {
logger.Errorf("error processing 3D job: %v", err)
s.updateJobStatus(&job, types.AI3DJobStatusFailed, 0, err.Error())
}
}()
}
}()
}
// processJob 处理3D任务
func (s *Service) processJob(job *model.AI3DJob) error {
// 更新状态为处理中
s.updateJobStatus(job, types.AI3DJobStatusProcessing, 10, "")
// 解析参数
var params map[string]any
if err := json.Unmarshal([]byte(job.Params), &params); err != nil {
return fmt.Errorf("failed to parse job params: %v", err)
}
var taskId string
var err error
// 根据类型选择客户端
switch job.Type {
case "tencent":
if s.tencentClient == nil {
return fmt.Errorf("tencent 3D client not initialized")
}
tencentParams := Tencent3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
EnablePBR: false,
}
taskId, err = s.tencentClient.SubmitJob(tencentParams)
case "gitee":
if s.giteeClient == nil {
return fmt.Errorf("gitee 3D client not initialized")
}
giteeParams := Gitee3DParams{
Prompt: s.getString(params, "prompt"),
ImageURL: s.getString(params, "image_url"),
ResultFormat: job.Model,
}
taskId, err = s.giteeClient.SubmitJob(giteeParams)
default:
return fmt.Errorf("unsupported 3D API type: %s", job.Type)
}
if err != nil {
return fmt.Errorf("failed to submit 3D job: %v", err)
}
// 更新任务ID
job.TaskId = taskId
s.db.Model(job).Update("task_id", taskId)
// 开始轮询任务状态
go s.pollJobStatus(job)
return nil
}
// pollJobStatus 轮询任务状态
func (s *Service) pollJobStatus(job *model.AI3DJob) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
result, err := s.queryJobStatus(job)
if err != nil {
logger.Errorf("failed to query job status: %v", err)
continue
}
// 更新进度
s.updateJobStatus(job, result.Status, result.Progress, result.ErrorMsg)
// 如果任务完成或失败,停止轮询
if result.Status == types.AI3DJobStatusCompleted || result.Status == types.AI3DJobStatusFailed {
if result.Status == types.AI3DJobStatusCompleted {
// 更新结果文件URL
s.db.Model(job).Updates(map[string]interface{}{
"img_url": result.FileURL,
"preview_url": result.PreviewURL,
})
}
return
}
}
}
}
// queryJobStatus 查询任务状态
func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, error) {
switch job.Type {
case "tencent":
if s.tencentClient == nil {
return nil, fmt.Errorf("tencent 3D client not initialized")
}
return s.tencentClient.QueryJob(job.TaskId)
case "gitee":
if s.giteeClient == nil {
return nil, fmt.Errorf("gitee 3D client not initialized")
}
return s.giteeClient.QueryJob(job.TaskId)
default:
return nil, fmt.Errorf("unsupported 3D API type: %s", job.Type)
}
}
// updateJobStatus 更新任务状态
func (s *Service) updateJobStatus(job *model.AI3DJob, status string, progress int, errMsg string) {
updates := map[string]interface{}{
"status": status,
"progress": progress,
"updated_at": time.Now(),
}
if errMsg != "" {
updates["err_msg"] = errMsg
}
if err := s.db.Model(job).Updates(updates).Error; err != nil {
logger.Errorf("failed to update job status: %v", err)
}
}
// GetJobList 获取任务列表
func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error) {
var total int64
var jobs []model.AI3DJob
// 查询总数
if err := s.db.Model(&model.AI3DJob{}).Where("user_id = ?", userId).Count(&total).Error; err != nil {
return nil, err
}
// 查询任务列表
offset := (page - 1) * pageSize
if err := s.db.Where("user_id = ?", userId).Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil {
return nil, err
}
// 转换为VO
var jobList []vo.AI3DJob
for _, job := range jobs {
jobVO := vo.AI3DJob{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Power: job.Power,
TaskId: job.TaskId,
ImgURL: job.FileURL,
PreviewURL: job.PreviewURL,
Model: job.Model,
Status: job.Status,
ErrMsg: job.ErrMsg,
Params: job.Params,
CreatedAt: job.CreatedAt.Unix(),
UpdatedAt: job.UpdatedAt.Unix(),
}
jobList = append(jobList, jobVO)
}
return &vo.Page{
Page: page,
PageSize: pageSize,
Total: total,
Items: jobList,
}, nil
}
// GetJobById 根据ID获取任务
func (s *Service) GetJobById(id uint) (*model.AI3DJob, error) {
var job model.AI3DJob
if err := s.db.Where("id = ?", id).First(&job).Error; err != nil {
return nil, err
}
return &job, nil
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(id uint, userId uint) error {
var job model.AI3DJob
if err := s.db.Where("id = ? AND user_id = ?", id, userId).First(&job).Error; err != nil {
return err
}
// 如果任务已完成,退还算力
if job.Status == types.AI3DJobStatusCompleted {
// TODO: 实现算力退还逻辑
logger2.GetLogger().Infof("should refund power %d for user %d", job.Power, userId)
}
return s.db.Delete(&job).Error
}
// GetSupportedModels 获取支持的模型列表
func (s *Service) GetSupportedModels() map[string][]types.AI3DModel {
models := make(map[string][]types.AI3DModel)
if s.tencentClient != nil {
models["tencent"] = s.tencentClient.GetSupportedModels()
}
if s.giteeClient != nil {
models["gitee"] = s.giteeClient.GetSupportedModels()
}
return models
}
func (s *Service) UpdateConfig(config types.AI3DConfig) {
if s.tencentClient != nil {
s.tencentClient.UpdateConfig(config.Tencent)
}
if s.giteeClient != nil {
s.giteeClient.UpdateConfig(config.Gitee)
}
}
// getString 从map中获取字符串值
func (s *Service) getString(params map[string]interface{}, key string) string {
if val, ok := params[key]; ok {
if str, ok := val.(string); ok {
return str
}
}
return ""
}

View File

@@ -0,0 +1,158 @@
package ai3d
import (
"fmt"
"geekai/core/types"
tencent3d "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d/v20250513"
tencentcloud "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
)
type Tencent3DClient struct {
client *tencent3d.Client
config types.Tencent3DConfig
}
type Tencent3DParams struct {
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
ResultFormat string `json:"result_format"` // 输出格式
EnablePBR bool `json:"enable_pbr"` // 是否开启PBR材质
MultiViewImages []ViewImage `json:"multi_view_images,omitempty"` // 多视角图片
}
type ViewImage struct {
ViewType string `json:"view_type"` // 视角类型 (left/right/back)
ViewImageURL string `json:"view_image_url"` // 图片URL
}
func NewTencent3DClient(sysConfig *types.SystemConfig) (*Tencent3DClient, error) {
config := sysConfig.AI3D.Tencent
credential := tencentcloud.NewCredential(config.SecretId, config.SecretKey)
cpf := profile.NewClientProfile()
cpf.HttpProfile.Endpoint = "ai3d.tencentcloudapi.com"
client, err := tencent3d.NewClient(credential, config.Region, cpf)
if err != nil {
return nil, fmt.Errorf("failed to create tencent 3D client: %v", err)
}
return &Tencent3DClient{
client: client,
config: config,
}, nil
}
func (c *Tencent3DClient) UpdateConfig(config types.Tencent3DConfig) error {
c.config = config
credential := tencentcloud.NewCredential(config.SecretId, config.SecretKey)
cpf := profile.NewClientProfile()
cpf.HttpProfile.Endpoint = "ai3d.tencentcloudapi.com"
client, err := tencent3d.NewClient(credential, config.Region, cpf)
if err != nil {
return fmt.Errorf("failed to create tencent 3D client: %v", err)
}
c.client = client
return nil
}
// SubmitJob 提交3D生成任务
func (c *Tencent3DClient) SubmitJob(params Tencent3DParams) (string, error) {
request := tencent3d.NewSubmitHunyuanTo3DJobRequest()
if params.Prompt != "" {
request.Prompt = tencentcloud.StringPtr(params.Prompt)
}
if params.ImageURL != "" {
request.ImageUrl = tencentcloud.StringPtr(params.ImageURL)
}
if params.ResultFormat != "" {
request.ResultFormat = tencentcloud.StringPtr(params.ResultFormat)
}
request.EnablePBR = tencentcloud.BoolPtr(params.EnablePBR)
if len(params.MultiViewImages) > 0 {
var viewImages []*tencent3d.ViewImage
for _, img := range params.MultiViewImages {
viewImage := &tencent3d.ViewImage{
ViewType: tencentcloud.StringPtr(img.ViewType),
ViewImageUrl: tencentcloud.StringPtr(img.ViewImageURL),
}
viewImages = append(viewImages, viewImage)
}
request.MultiViewImages = viewImages
}
response, err := c.client.SubmitHunyuanTo3DJob(request)
if err != nil {
return "", fmt.Errorf("failed to submit tencent 3D job: %v", err)
}
if response.Response.JobId == nil {
return "", fmt.Errorf("no job ID returned from tencent 3D API")
}
return *response.Response.JobId, nil
}
// QueryJob 查询任务状态
func (c *Tencent3DClient) QueryJob(jobId string) (*types.AI3DJobResult, error) {
request := tencent3d.NewQueryHunyuanTo3DJobRequest()
request.JobId = tencentcloud.StringPtr(jobId)
response, err := c.client.QueryHunyuanTo3DJob(request)
if err != nil {
return nil, fmt.Errorf("failed to query tencent 3D job: %v", err)
}
result := &types.AI3DJobResult{
JobId: jobId,
Status: *response.Response.Status,
Progress: 0,
}
// 根据状态设置进度
switch *response.Response.Status {
case "WAIT":
result.Status = "pending"
result.Progress = 10
case "RUN":
result.Status = "processing"
result.Progress = 50
case "DONE":
result.Status = "completed"
result.Progress = 100
// 处理结果文件
if len(response.Response.ResultFile3Ds) > 0 {
for _, file := range response.Response.ResultFile3Ds {
if file.Url != nil {
result.FileURL = *file.Url
}
if file.PreviewImageUrl != nil {
result.PreviewURL = *file.PreviewImageUrl
}
break // 取第一个文件
}
}
case "FAIL":
result.Status = "failed"
result.Progress = 0
if response.Response.ErrorMessage != nil {
result.ErrorMsg = *response.Response.ErrorMessage
}
}
return result, nil
}
// GetSupportedModels 获取支持的模型列表
func (c *Tencent3DClient) GetSupportedModels() []types.AI3DModel {
return []types.AI3DModel{
{Name: "Hunyuan3D-3", Power: 500, Formats: []string{"OBJ", "GLB", "STL", "USDZ", "FBX", "MP4"}, Desc: "Hunyuan3D 是腾讯混元团队推出的高质量 3D 生成模型,具备高保真度、细节丰富和高效生成的特点,可快速将文本或图像转换为逼真的 3D 物体。"},
}
}

View File

@@ -3,8 +3,10 @@ package jimeng
import (
"encoding/json"
"fmt"
"geekai/core/types"
"net/http"
"net/url"
"strings"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/volcengine/volc-sdk-golang/service/visual"
@@ -13,14 +15,22 @@ import (
// Client 即梦API客户端
type Client struct {
visual *visual.Visual
config types.JimengConfig
}
// NewClient 创建即梦API客户端
func NewClient(accessKey, secretKey string) *Client {
func NewClient(sysConfig *types.SystemConfig) *Client {
client := &Client{}
client.UpdateConfig(sysConfig.Jimeng)
return client
}
func (c *Client) UpdateConfig(config types.JimengConfig) error {
// 使用官方SDK的visual实例
visualInstance := visual.NewInstance()
visualInstance.Client.SetAccessKey(accessKey)
visualInstance.Client.SetSecretKey(secretKey)
visualInstance.Client.SetAccessKey(config.AccessKey)
visualInstance.Client.SetSecretKey(config.SecretKey)
// 添加即梦AI专有的API配置
jimengApis := map[string]*base.ApiInfo{
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
visualInstance.Client.ApiInfoList[name] = info
}
return &Client{
visual: visualInstance,
c.config = config
c.visual = visualInstance
return c.testConnection()
}
// testConnection 测试即梦AI连接
func (c *Client) testConnection() error {
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := c.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// SubmitTask 提交异步任务

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -16,8 +15,6 @@ import (
"geekai/store/model"
"geekai/utils"
"geekai/core/types"
"github.com/go-redis/redis/v8"
)
@@ -36,17 +33,8 @@ type Service struct {
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
db.Where("name = ?", "Jimeng").First(&config)
var jimengConfig types.JimengConfig
if config.Id > 0 {
_ = utils.JsonDecode(config.Value, &jimengConfig)
}
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
ctx, cancel := context.WithCancel(context.Background())
return &Service{
db: db,
@@ -522,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
}
return &job, nil
}
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := testClient.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// UpdateClientConfig 更新客户端配置
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
// 创建新的客户端
newClient := NewClient(accessKey, secretKey)
// 测试新客户端是否可用
err := s.testConnection(accessKey, secretKey)
if err != nil {
return err
}
// 更新客户端
s.client = newClient
return nil
}
var defaultPower = types.JimengPower{
TextToImage: 20,
ImageToImage: 20,
ImageEdit: 20,
ImageEffects: 20,
TextToVideo: 300,
ImageToVideo: 300,
}
// GetConfig 获取即梦AI配置
func (s *Service) GetConfig() *types.JimengConfig {
var config model.Config
err := s.db.Where("name", "jimeng").First(&config).Error
if err != nil {
// 如果配置不存在,返回默认配置
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
var jimengConfig types.JimengConfig
err = utils.JsonDecode(config.Value, &jimengConfig)
if err != nil {
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
return &jimengConfig
}

View File

@@ -154,6 +154,24 @@ func (s *MigrationService) MigrateConfigContent() error {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 3D生成配置
if err := s.saveConfig(types.ConfigKeyAI3D, map[string]any{
"tencent": map[string]any{
"access_key": "",
"secret_key": "",
"region": "",
"enabled": false,
"models": make([]types.AI3DModel, 0),
},
"gitee": map[string]any{
"api_key": "",
"enabled": false,
"models": make([]types.AI3DModel, 0),
},
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
return nil
}
@@ -161,6 +179,8 @@ func (s *MigrationService) MigrateConfigContent() error {
func (s *MigrationService) TableMigration() {
// 新数据表
s.db.AutoMigrate(&model.Moderation{})
s.db.AutoMigrate(&model.AI3DJob{})
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")

View File

@@ -0,0 +1,23 @@
package model
import "time"
type AI3DJob struct {
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int(11);not null;comment:用户ID" json:"user_id"`
Type string `gorm:"column:type;type:varchar(20);not null;comment:API类型 (tencent/gitee)" json:"type"`
Power int `gorm:"column:power;type:int(11);not null;comment:消耗算力" json:"power"`
TaskId string `gorm:"column:task_id;type:varchar(100);comment:第三方任务ID" json:"task_id"`
FileURL string `gorm:"column:file_url;type:varchar(1024);comment:生成的3D模型文件地址" json:"file_url"`
PreviewURL string `gorm:"column:preview_url;type:varchar(1024);comment:预览图片地址" json:"preview_url"`
Model string `gorm:"column:model;type:varchar(50);comment:使用的3D模型类型" json:"model"`
Status string `gorm:"column:status;type:varchar(20);not null;default:pending;comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Params string `gorm:"column:params;type:text;comment:任务参数(JSON格式)" json:"params"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
}
func (m *AI3DJob) TableName() string {
return "geekai_3d_jobs"
}

32
api/store/vo/ai3d_job.go Normal file
View File

@@ -0,0 +1,32 @@
package vo
type AI3DJob struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
Type string `json:"type"`
Power int `json:"power"`
TaskId string `json:"task_id"`
ImgURL string `json:"img_url"`
PreviewURL string `json:"preview_url"`
Model string `json:"model"`
Status string `json:"status"`
ErrMsg string `json:"err_msg"`
Params string `json:"params"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
type AI3DJobCreate struct {
Type string `json:"type" binding:"required"` // API类型 (tencent/gitee)
Model string `json:"model" binding:"required"` // 3D模型类型
Prompt string `json:"prompt"` // 文本提示词
ImageURL string `json:"image_url"` // 输入图片URL
Power int `json:"power" binding:"required"` // 消耗算力
}
type ThreeDJobList struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int `json:"total"`
List []AI3DJob `json:"list"`
}