mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-26 04:54:28 +08:00
3D生成服务已经完成
This commit is contained in:
150
api/service/ai3d/gitee_client.go
Normal file
150
api/service/ai3d/gitee_client.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package ai3d
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type Gitee3DClient struct {
|
||||
httpClient *req.Client
|
||||
config types.Gitee3DConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
type Gitee3DParams struct {
|
||||
Prompt string `json:"prompt"` // 文本提示词
|
||||
ImageURL string `json:"image_url"` // 输入图片URL
|
||||
ResultFormat string `json:"result_format"` // 输出格式
|
||||
}
|
||||
|
||||
type Gitee3DResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type Gitee3DQueryResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
Status string `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
ResultURL string `json:"result_url"`
|
||||
PreviewURL string `json:"preview_url"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func NewGitee3DClient(sysConfig *types.SystemConfig) *Gitee3DClient {
|
||||
return &Gitee3DClient{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
config: sysConfig.AI3D.Gitee,
|
||||
apiURL: "https://ai.gitee.com/v1/async/image-to-3d",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Gitee3DClient) UpdateConfig(config types.Gitee3DConfig) {
|
||||
c.config = config
|
||||
}
|
||||
|
||||
// SubmitJob 提交3D生成任务
|
||||
func (c *Gitee3DClient) SubmitJob(params Gitee3DParams) (string, error) {
|
||||
requestBody := map[string]any{
|
||||
"prompt": params.Prompt,
|
||||
"image_url": params.ImageURL,
|
||||
"result_format": params.ResultFormat,
|
||||
}
|
||||
|
||||
response, err := c.httpClient.R().
|
||||
SetHeader("Authorization", "Bearer "+c.config.APIKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(requestBody).
|
||||
Post(c.apiURL + "/async/image-to-3d")
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to submit gitee 3D job: %v", err)
|
||||
}
|
||||
|
||||
var giteeResp Gitee3DResponse
|
||||
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse gitee response: %v", err)
|
||||
}
|
||||
|
||||
if giteeResp.Code != 0 {
|
||||
return "", fmt.Errorf("gitee API error: %s", giteeResp.Message)
|
||||
}
|
||||
|
||||
if giteeResp.Data.TaskID == "" {
|
||||
return "", fmt.Errorf("no task ID returned from gitee 3D API")
|
||||
}
|
||||
|
||||
return giteeResp.Data.TaskID, nil
|
||||
}
|
||||
|
||||
// QueryJob 查询任务状态
|
||||
func (c *Gitee3DClient) QueryJob(taskId string) (*types.AI3DJobResult, error) {
|
||||
response, err := c.httpClient.R().
|
||||
SetHeader("Authorization", "Bearer "+c.config.APIKey).
|
||||
Get(fmt.Sprintf("%s/task/%s/get", c.apiURL, taskId))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query gitee 3D job: %v", err)
|
||||
}
|
||||
|
||||
var giteeResp Gitee3DQueryResponse
|
||||
if err := json.Unmarshal(response.Bytes(), &giteeResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse gitee query response: %v", err)
|
||||
}
|
||||
|
||||
if giteeResp.Code != 0 {
|
||||
return nil, fmt.Errorf("gitee API error: %s", giteeResp.Message)
|
||||
}
|
||||
|
||||
result := &types.AI3DJobResult{
|
||||
JobId: taskId,
|
||||
Status: c.convertStatus(giteeResp.Data.Status),
|
||||
Progress: giteeResp.Data.Progress,
|
||||
}
|
||||
|
||||
// 根据状态设置结果
|
||||
switch giteeResp.Data.Status {
|
||||
case "completed":
|
||||
result.FileURL = giteeResp.Data.ResultURL
|
||||
result.PreviewURL = giteeResp.Data.PreviewURL
|
||||
case "failed":
|
||||
result.ErrorMsg = giteeResp.Data.ErrorMsg
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// convertStatus 转换Gitee状态到系统状态
|
||||
func (c *Gitee3DClient) convertStatus(giteeStatus string) string {
|
||||
switch giteeStatus {
|
||||
case "pending":
|
||||
return types.AI3DJobStatusPending
|
||||
case "processing":
|
||||
return types.AI3DJobStatusProcessing
|
||||
case "completed":
|
||||
return types.AI3DJobStatusCompleted
|
||||
case "failed":
|
||||
return types.AI3DJobStatusFailed
|
||||
default:
|
||||
return types.AI3DJobStatusPending
|
||||
}
|
||||
}
|
||||
|
||||
// GetSupportedModels 获取支持的模型列表
|
||||
func (c *Gitee3DClient) GetSupportedModels() []types.AI3DModel {
|
||||
return []types.AI3DModel{
|
||||
{Name: "Hunyuan3D-2", Power: 100, Formats: []string{"GLB"}, Desc: "Hunyuan3D-2 是腾讯混元团队推出的高质量 3D 生成模型,具备高保真度、细节丰富和高效生成的特点,可快速将文本或图像转换为逼真的 3D 物体。"},
|
||||
{Name: "Step1X-3D", Power: 55, Formats: []string{"GLB", "STL"}, Desc: "Step1X-3D 是一款由阶跃星辰(StepFun)与光影焕像(LightIllusions)联合研发并开源的高保真 3D 生成模型,专为高质量、可控的 3D 内容创作而设计。"},
|
||||
{Name: "Hi3DGen", Power: 35, Formats: []string{"GLB", "STL"}, Desc: "Hi3DGen 是一个 AI 工具,它可以把你上传的普通图片,智能转换成有“立体感”的图片(法线图),常用于制作 3D 效果,比如游戏建模、虚拟现实、动画制作等。"},
|
||||
}
|
||||
}
|
||||
327
api/service/ai3d/service.go
Normal file
327
api/service/ai3d/service.go
Normal file
@@ -0,0 +1,327 @@
|
||||
package ai3d
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// Service 3D生成服务
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
taskQueue *store.RedisQueue
|
||||
tencentClient *Tencent3DClient
|
||||
giteeClient *Gitee3DClient
|
||||
}
|
||||
|
||||
// NewService 创建3D生成服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, tencentClient *Tencent3DClient, giteeClient *Gitee3DClient) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("3D_Task_Queue", redisCli),
|
||||
tencentClient: tencentClient,
|
||||
giteeClient: giteeClient,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateJob 创建3D生成任务
|
||||
func (s *Service) CreateJob(userId uint, request vo.AI3DJobCreate) (*model.AI3DJob, error) {
|
||||
// 创建任务记录
|
||||
job := &model.AI3DJob{
|
||||
UserId: userId,
|
||||
Type: request.Type,
|
||||
Power: request.Power,
|
||||
Model: request.Model,
|
||||
Status: types.AI3DJobStatusPending,
|
||||
}
|
||||
|
||||
// 序列化参数
|
||||
params := map[string]any{
|
||||
"prompt": request.Prompt,
|
||||
"image_url": request.ImageURL,
|
||||
"model": request.Model,
|
||||
"power": request.Power,
|
||||
}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
job.Params = string(paramsJSON)
|
||||
|
||||
// 保存到数据库
|
||||
if err := s.db.Create(job).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create 3D job: %v", err)
|
||||
}
|
||||
|
||||
// 将任务添加到队列
|
||||
s.PushTask(job)
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// PushTask 将任务添加到队列
|
||||
func (s *Service) PushTask(job *model.AI3DJob) {
|
||||
logger.Infof("add a new 3D task to the queue: %+v", job)
|
||||
if err := s.taskQueue.RPush(job); err != nil {
|
||||
logger.Errorf("push 3D task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run 启动任务处理器
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未完成的任务加载到队列
|
||||
var jobs []model.AI3DJob
|
||||
s.db.Where("status IN ?", []string{types.AI3DJobStatusPending, types.AI3DJobStatusProcessing}).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
s.PushTask(&job)
|
||||
}
|
||||
|
||||
logger.Info("Starting 3D job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var job model.AI3DJob
|
||||
err := s.taskQueue.LPop(&job)
|
||||
if err != nil {
|
||||
logger.Errorf("taking 3D task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new 3D task: %+v", job)
|
||||
go func() {
|
||||
if err := s.processJob(&job); err != nil {
|
||||
logger.Errorf("error processing 3D job: %v", err)
|
||||
s.updateJobStatus(&job, types.AI3DJobStatusFailed, 0, err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// processJob 处理3D任务
|
||||
func (s *Service) processJob(job *model.AI3DJob) error {
|
||||
// 更新状态为处理中
|
||||
s.updateJobStatus(job, types.AI3DJobStatusProcessing, 10, "")
|
||||
|
||||
// 解析参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.Params), ¶ms); err != nil {
|
||||
return fmt.Errorf("failed to parse job params: %v", err)
|
||||
}
|
||||
|
||||
var taskId string
|
||||
var err error
|
||||
|
||||
// 根据类型选择客户端
|
||||
switch job.Type {
|
||||
case "tencent":
|
||||
if s.tencentClient == nil {
|
||||
return fmt.Errorf("tencent 3D client not initialized")
|
||||
}
|
||||
tencentParams := Tencent3DParams{
|
||||
Prompt: s.getString(params, "prompt"),
|
||||
ImageURL: s.getString(params, "image_url"),
|
||||
ResultFormat: job.Model,
|
||||
EnablePBR: false,
|
||||
}
|
||||
taskId, err = s.tencentClient.SubmitJob(tencentParams)
|
||||
case "gitee":
|
||||
if s.giteeClient == nil {
|
||||
return fmt.Errorf("gitee 3D client not initialized")
|
||||
}
|
||||
giteeParams := Gitee3DParams{
|
||||
Prompt: s.getString(params, "prompt"),
|
||||
ImageURL: s.getString(params, "image_url"),
|
||||
ResultFormat: job.Model,
|
||||
}
|
||||
taskId, err = s.giteeClient.SubmitJob(giteeParams)
|
||||
default:
|
||||
return fmt.Errorf("unsupported 3D API type: %s", job.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to submit 3D job: %v", err)
|
||||
}
|
||||
|
||||
// 更新任务ID
|
||||
job.TaskId = taskId
|
||||
s.db.Model(job).Update("task_id", taskId)
|
||||
|
||||
// 开始轮询任务状态
|
||||
go s.pollJobStatus(job)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pollJobStatus 轮询任务状态
|
||||
func (s *Service) pollJobStatus(job *model.AI3DJob) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
result, err := s.queryJobStatus(job)
|
||||
if err != nil {
|
||||
logger.Errorf("failed to query job status: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新进度
|
||||
s.updateJobStatus(job, result.Status, result.Progress, result.ErrorMsg)
|
||||
|
||||
// 如果任务完成或失败,停止轮询
|
||||
if result.Status == types.AI3DJobStatusCompleted || result.Status == types.AI3DJobStatusFailed {
|
||||
if result.Status == types.AI3DJobStatusCompleted {
|
||||
// 更新结果文件URL
|
||||
s.db.Model(job).Updates(map[string]interface{}{
|
||||
"img_url": result.FileURL,
|
||||
"preview_url": result.PreviewURL,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// queryJobStatus 查询任务状态
|
||||
func (s *Service) queryJobStatus(job *model.AI3DJob) (*types.AI3DJobResult, error) {
|
||||
switch job.Type {
|
||||
case "tencent":
|
||||
if s.tencentClient == nil {
|
||||
return nil, fmt.Errorf("tencent 3D client not initialized")
|
||||
}
|
||||
return s.tencentClient.QueryJob(job.TaskId)
|
||||
case "gitee":
|
||||
if s.giteeClient == nil {
|
||||
return nil, fmt.Errorf("gitee 3D client not initialized")
|
||||
}
|
||||
return s.giteeClient.QueryJob(job.TaskId)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported 3D API type: %s", job.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// updateJobStatus 更新任务状态
|
||||
func (s *Service) updateJobStatus(job *model.AI3DJob, status string, progress int, errMsg string) {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if errMsg != "" {
|
||||
updates["err_msg"] = errMsg
|
||||
}
|
||||
|
||||
if err := s.db.Model(job).Updates(updates).Error; err != nil {
|
||||
logger.Errorf("failed to update job status: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetJobList 获取任务列表
|
||||
func (s *Service) GetJobList(userId uint, page, pageSize int) (*vo.Page, error) {
|
||||
var total int64
|
||||
var jobs []model.AI3DJob
|
||||
|
||||
// 查询总数
|
||||
if err := s.db.Model(&model.AI3DJob{}).Where("user_id = ?", userId).Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询任务列表
|
||||
offset := (page - 1) * pageSize
|
||||
if err := s.db.Where("user_id = ?", userId).Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为VO
|
||||
var jobList []vo.AI3DJob
|
||||
for _, job := range jobs {
|
||||
jobVO := vo.AI3DJob{
|
||||
Id: job.Id,
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
Power: job.Power,
|
||||
TaskId: job.TaskId,
|
||||
ImgURL: job.FileURL,
|
||||
PreviewURL: job.PreviewURL,
|
||||
Model: job.Model,
|
||||
Status: job.Status,
|
||||
ErrMsg: job.ErrMsg,
|
||||
Params: job.Params,
|
||||
CreatedAt: job.CreatedAt.Unix(),
|
||||
UpdatedAt: job.UpdatedAt.Unix(),
|
||||
}
|
||||
jobList = append(jobList, jobVO)
|
||||
}
|
||||
|
||||
return &vo.Page{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
Items: jobList,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetJobById 根据ID获取任务
|
||||
func (s *Service) GetJobById(id uint) (*model.AI3DJob, error) {
|
||||
var job model.AI3DJob
|
||||
if err := s.db.Where("id = ?", id).First(&job).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// DeleteJob 删除任务
|
||||
func (s *Service) DeleteJob(id uint, userId uint) error {
|
||||
var job model.AI3DJob
|
||||
if err := s.db.Where("id = ? AND user_id = ?", id, userId).First(&job).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果任务已完成,退还算力
|
||||
if job.Status == types.AI3DJobStatusCompleted {
|
||||
// TODO: 实现算力退还逻辑
|
||||
logger2.GetLogger().Infof("should refund power %d for user %d", job.Power, userId)
|
||||
}
|
||||
|
||||
return s.db.Delete(&job).Error
|
||||
}
|
||||
|
||||
// GetSupportedModels 获取支持的模型列表
|
||||
func (s *Service) GetSupportedModels() map[string][]types.AI3DModel {
|
||||
|
||||
models := make(map[string][]types.AI3DModel)
|
||||
if s.tencentClient != nil {
|
||||
models["tencent"] = s.tencentClient.GetSupportedModels()
|
||||
}
|
||||
if s.giteeClient != nil {
|
||||
models["gitee"] = s.giteeClient.GetSupportedModels()
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (s *Service) UpdateConfig(config types.AI3DConfig) {
|
||||
if s.tencentClient != nil {
|
||||
s.tencentClient.UpdateConfig(config.Tencent)
|
||||
}
|
||||
if s.giteeClient != nil {
|
||||
s.giteeClient.UpdateConfig(config.Gitee)
|
||||
}
|
||||
}
|
||||
|
||||
// getString 从map中获取字符串值
|
||||
func (s *Service) getString(params map[string]interface{}, key string) string {
|
||||
if val, ok := params[key]; ok {
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
158
api/service/ai3d/tencent_client.go
Normal file
158
api/service/ai3d/tencent_client.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package ai3d
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
|
||||
tencent3d "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ai3d/v20250513"
|
||||
tencentcloud "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
)
|
||||
|
||||
type Tencent3DClient struct {
|
||||
client *tencent3d.Client
|
||||
config types.Tencent3DConfig
|
||||
}
|
||||
|
||||
type Tencent3DParams struct {
|
||||
Prompt string `json:"prompt"` // 文本提示词
|
||||
ImageURL string `json:"image_url"` // 输入图片URL
|
||||
ResultFormat string `json:"result_format"` // 输出格式
|
||||
EnablePBR bool `json:"enable_pbr"` // 是否开启PBR材质
|
||||
MultiViewImages []ViewImage `json:"multi_view_images,omitempty"` // 多视角图片
|
||||
}
|
||||
|
||||
type ViewImage struct {
|
||||
ViewType string `json:"view_type"` // 视角类型 (left/right/back)
|
||||
ViewImageURL string `json:"view_image_url"` // 图片URL
|
||||
}
|
||||
|
||||
func NewTencent3DClient(sysConfig *types.SystemConfig) (*Tencent3DClient, error) {
|
||||
config := sysConfig.AI3D.Tencent
|
||||
credential := tencentcloud.NewCredential(config.SecretId, config.SecretKey)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.Endpoint = "ai3d.tencentcloudapi.com"
|
||||
|
||||
client, err := tencent3d.NewClient(credential, config.Region, cpf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tencent 3D client: %v", err)
|
||||
}
|
||||
|
||||
return &Tencent3DClient{
|
||||
client: client,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Tencent3DClient) UpdateConfig(config types.Tencent3DConfig) error {
|
||||
c.config = config
|
||||
credential := tencentcloud.NewCredential(config.SecretId, config.SecretKey)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.Endpoint = "ai3d.tencentcloudapi.com"
|
||||
|
||||
client, err := tencent3d.NewClient(credential, config.Region, cpf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tencent 3D client: %v", err)
|
||||
}
|
||||
c.client = client
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubmitJob 提交3D生成任务
|
||||
func (c *Tencent3DClient) SubmitJob(params Tencent3DParams) (string, error) {
|
||||
request := tencent3d.NewSubmitHunyuanTo3DJobRequest()
|
||||
|
||||
if params.Prompt != "" {
|
||||
request.Prompt = tencentcloud.StringPtr(params.Prompt)
|
||||
}
|
||||
|
||||
if params.ImageURL != "" {
|
||||
request.ImageUrl = tencentcloud.StringPtr(params.ImageURL)
|
||||
}
|
||||
|
||||
if params.ResultFormat != "" {
|
||||
request.ResultFormat = tencentcloud.StringPtr(params.ResultFormat)
|
||||
}
|
||||
|
||||
request.EnablePBR = tencentcloud.BoolPtr(params.EnablePBR)
|
||||
|
||||
if len(params.MultiViewImages) > 0 {
|
||||
var viewImages []*tencent3d.ViewImage
|
||||
for _, img := range params.MultiViewImages {
|
||||
viewImage := &tencent3d.ViewImage{
|
||||
ViewType: tencentcloud.StringPtr(img.ViewType),
|
||||
ViewImageUrl: tencentcloud.StringPtr(img.ViewImageURL),
|
||||
}
|
||||
viewImages = append(viewImages, viewImage)
|
||||
}
|
||||
request.MultiViewImages = viewImages
|
||||
}
|
||||
|
||||
response, err := c.client.SubmitHunyuanTo3DJob(request)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to submit tencent 3D job: %v", err)
|
||||
}
|
||||
|
||||
if response.Response.JobId == nil {
|
||||
return "", fmt.Errorf("no job ID returned from tencent 3D API")
|
||||
}
|
||||
|
||||
return *response.Response.JobId, nil
|
||||
}
|
||||
|
||||
// QueryJob 查询任务状态
|
||||
func (c *Tencent3DClient) QueryJob(jobId string) (*types.AI3DJobResult, error) {
|
||||
request := tencent3d.NewQueryHunyuanTo3DJobRequest()
|
||||
request.JobId = tencentcloud.StringPtr(jobId)
|
||||
|
||||
response, err := c.client.QueryHunyuanTo3DJob(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tencent 3D job: %v", err)
|
||||
}
|
||||
|
||||
result := &types.AI3DJobResult{
|
||||
JobId: jobId,
|
||||
Status: *response.Response.Status,
|
||||
Progress: 0,
|
||||
}
|
||||
|
||||
// 根据状态设置进度
|
||||
switch *response.Response.Status {
|
||||
case "WAIT":
|
||||
result.Status = "pending"
|
||||
result.Progress = 10
|
||||
case "RUN":
|
||||
result.Status = "processing"
|
||||
result.Progress = 50
|
||||
case "DONE":
|
||||
result.Status = "completed"
|
||||
result.Progress = 100
|
||||
// 处理结果文件
|
||||
if len(response.Response.ResultFile3Ds) > 0 {
|
||||
for _, file := range response.Response.ResultFile3Ds {
|
||||
if file.Url != nil {
|
||||
result.FileURL = *file.Url
|
||||
}
|
||||
if file.PreviewImageUrl != nil {
|
||||
result.PreviewURL = *file.PreviewImageUrl
|
||||
}
|
||||
break // 取第一个文件
|
||||
}
|
||||
}
|
||||
case "FAIL":
|
||||
result.Status = "failed"
|
||||
result.Progress = 0
|
||||
if response.Response.ErrorMessage != nil {
|
||||
result.ErrorMsg = *response.Response.ErrorMessage
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetSupportedModels 获取支持的模型列表
|
||||
func (c *Tencent3DClient) GetSupportedModels() []types.AI3DModel {
|
||||
return []types.AI3DModel{
|
||||
{Name: "Hunyuan3D-3", Power: 500, Formats: []string{"OBJ", "GLB", "STL", "USDZ", "FBX", "MP4"}, Desc: "Hunyuan3D 是腾讯混元团队推出的高质量 3D 生成模型,具备高保真度、细节丰富和高效生成的特点,可快速将文本或图像转换为逼真的 3D 物体。"},
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,10 @@ package jimeng
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
@@ -13,14 +15,22 @@ import (
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
config types.JimengConfig
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
func NewClient(sysConfig *types.SystemConfig) *Client {
|
||||
|
||||
client := &Client{}
|
||||
client.UpdateConfig(sysConfig.Jimeng)
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(accessKey)
|
||||
visualInstance.Client.SetSecretKey(secretKey)
|
||||
visualInstance.Client.SetAccessKey(config.AccessKey)
|
||||
visualInstance.Client.SetSecretKey(config.SecretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
return &Client{
|
||||
visual: visualInstance,
|
||||
c.config = config
|
||||
c.visual = visualInstance
|
||||
|
||||
return c.testConnection()
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (c *Client) testConnection() error {
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := c.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -16,8 +15,6 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
@@ -36,17 +33,8 @@ type Service struct {
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
@@ -522,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := testClient.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClientConfig 更新客户端配置
|
||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
// 创建新的客户端
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
s.client = newClient
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultPower = types.JimengPower{
|
||||
TextToImage: 20,
|
||||
ImageToImage: 20,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 20,
|
||||
TextToVideo: 300,
|
||||
ImageToVideo: 300,
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (s *Service) GetConfig() *types.JimengConfig {
|
||||
var config model.Config
|
||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
return &jimengConfig
|
||||
}
|
||||
|
||||
@@ -154,6 +154,24 @@ func (s *MigrationService) MigrateConfigContent() error {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 3D生成配置
|
||||
if err := s.saveConfig(types.ConfigKeyAI3D, map[string]any{
|
||||
"tencent": map[string]any{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
"region": "",
|
||||
"enabled": false,
|
||||
"models": make([]types.AI3DModel, 0),
|
||||
},
|
||||
"gitee": map[string]any{
|
||||
"api_key": "",
|
||||
"enabled": false,
|
||||
"models": make([]types.AI3DModel, 0),
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -161,6 +179,8 @@ func (s *MigrationService) MigrateConfigContent() error {
|
||||
func (s *MigrationService) TableMigration() {
|
||||
// 新数据表
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
s.db.AutoMigrate(&model.AI3DJob{})
|
||||
|
||||
// 订单字段整理
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
|
||||
|
||||
Reference in New Issue
Block a user