diff --git a/api/core/types/config.go b/api/core/types/config.go index 01b6bc02..bd57b31d 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -43,9 +43,16 @@ type SmtpConfig struct { } type ApiConfig struct { - ApiURL string - AppId string - Token string + ApiURL string + AppId string + Token string + JimengConfig JimengConfig // 即梦AI配置 +} + +// JimengConfig 即梦AI配置 +type JimengConfig struct { + AccessKey string // 火山引擎AccessKey + SecretKey string // 火山引擎SecretKey } type AlipayConfig struct { @@ -170,7 +177,7 @@ type SystemConfig struct { EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码 EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表 - TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id + AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB } diff --git a/api/handler/admin/jimeng_handler.go b/api/handler/admin/jimeng_handler.go new file mode 100644 index 00000000..7580a50f --- /dev/null +++ b/api/handler/admin/jimeng_handler.go @@ -0,0 +1,177 @@ +package admin + +import ( + "strconv" + + "geekai/core" + "geekai/handler" + "geekai/service/jimeng" + "geekai/store/model" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" +) + +// AdminJimengHandler 管理后台即梦AI处理器 +type AdminJimengHandler struct { + handler.BaseHandler + jimengService *jimeng.Service +} + +// NewAdminJimengHandler 创建管理后台即梦AI处理器 +func NewAdminJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *AdminJimengHandler { + return &AdminJimengHandler{ + BaseHandler: handler.BaseHandler{App: app}, + jimengService: jimengService, + } +} + +// Jobs 获取任务列表 +func (h *AdminJimengHandler) Jobs(c *gin.Context) { + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + userId := h.GetInt(c, "user_id", 0) + taskType := h.GetTrim(c, "type") + status := h.GetTrim(c, "status") + + var tasks []model.JimengJob + var total int64 + + session := h.DB.Model(&model.JimengJob{}) + + // 构建查询条件 + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if taskType != "" { + session = session.Where("type = ?", taskType) + } + if status != "" { + session = session.Where("status = ?", status) + } + + // 获取总数 + err := session.Count(&total).Error + if err != nil { + resp.ERROR(c, "获取任务数量失败") + return + } + + // 获取数据 + offset := (page - 1) * pageSize + err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error + if err != nil { + resp.ERROR(c, "获取任务列表失败") + return + } + + resp.SUCCESS(c, gin.H{ + "jobs": tasks, + "total": total, + "page": page, + "page_size": pageSize, + }) +} + +// JobDetail 获取任务详情 +func (h *AdminJimengHandler) JobDetail(c *gin.Context) { + idStr := c.Param("id") + jobId, err := strconv.ParseUint(idStr, 10, 32) + if err != nil { + resp.ERROR(c, "参数错误") + return + } + + var job model.JimengJob + err = h.DB.Where("id = ?", jobId).First(&job).Error + if err != nil { + resp.ERROR(c, "任务不存在") + return + } + + resp.SUCCESS(c, job) +} + +// Remove 删除任务 +func (h *AdminJimengHandler) Remove(c *gin.Context) { + idStr := c.Param("id") + jobId, err := strconv.ParseUint(idStr, 10, 32) + if err != nil { + resp.ERROR(c, "参数错误") + return + } + + err = h.DB.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error + if err != nil { + resp.ERROR(c, "删除任务失败") + return + } + + resp.SUCCESS(c, gin.H{}) +} + +// BatchRemove 批量删除任务 +func (h *AdminJimengHandler) BatchRemove(c *gin.Context) { + var req struct { + JobIds []uint `json:"job_ids" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误") + return + } + + result := h.DB.Where("id IN ?", req.JobIds).Delete(&model.JimengJob{}) + if result.Error != nil { + resp.ERROR(c, "批量删除失败") + return + } + + resp.SUCCESS(c, gin.H{ + "message": "批量删除成功", + "deleted_count": result.RowsAffected, + }) +} + +// Stats 获取统计信息 +func (h *AdminJimengHandler) Stats(c *gin.Context) { + type StatResult struct { + Status string `json:"status"` + Count int64 `json:"count"` + } + + var stats []StatResult + err := h.DB.Model(&model.JimengJob{}). + Select("status, COUNT(*) as count"). + Group("status"). + Find(&stats).Error + if err != nil { + resp.ERROR(c, "获取统计信息失败") + return + } + + // 整理统计数据 + result := gin.H{ + "totalTasks": int64(0), + "completedTasks": int64(0), + "processingTasks": int64(0), + "failedTasks": int64(0), + "pendingTasks": int64(0), + } + + for _, stat := range stats { + result["totalTasks"] = result["totalTasks"].(int64) + stat.Count + switch stat.Status { + case "completed": + result["completedTasks"] = stat.Count + case "processing": + result["processingTasks"] = stat.Count + case "failed": + result["failedTasks"] = stat.Count + case "pending": + result["pendingTasks"] = stat.Count + } + } + + resp.SUCCESS(c, result) +} \ No newline at end of file diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index 256bd07a..0c7bc037 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -77,7 +77,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { Quality: data.Quality, Size: data.Size, Style: data.Style, - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, Power: chatModel.Power, } job := model.DallJob{ diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index 9cf59a8a..fb6d6cd4 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -213,7 +213,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { Prompt: prompt, ModelId: 0, ModelName: "dall-e-3", - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, N: 1, Quality: "standard", Size: "1024x1024", @@ -265,27 +265,27 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - + // 从参数中获取搜索关键词 keyword, ok := params["keyword"].(string) if !ok || keyword == "" { resp.ERROR(c, "搜索关键词不能为空") return } - + // 从参数中获取最大页数,默认为1页 maxPages := 1 if pages, ok := params["max_pages"].(float64); ok { maxPages = int(pages) } - + // 获取用户ID userID, ok := params["user_id"].(float64) if !ok { resp.ERROR(c, "用户ID不能为空") return } - + // 查询用户信息 var user model.User res := h.DB.Where("id = ?", int(userID)).First(&user) @@ -293,21 +293,21 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) { resp.ERROR(c, "用户不存在") return } - + // 检查用户算力是否足够 searchPower := 1 // 每次搜索消耗1点算力 if user.Power < searchPower { resp.ERROR(c, "算力不足,无法执行网络搜索") return } - + // 执行网络搜索 searchResults, err := crawler.SearchWeb(keyword, maxPages) if err != nil { resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err)) return } - + // 扣减用户算力 err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{ Type: types.PowerConsume, @@ -318,7 +318,7 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) { resp.ERROR(c, "扣减算力失败:"+err.Error()) return } - + // 返回搜索结果 resp.SUCCESS(c, searchResults) } diff --git a/api/handler/jimeng_handler.go b/api/handler/jimeng_handler.go new file mode 100644 index 00000000..ba832830 --- /dev/null +++ b/api/handler/jimeng_handler.go @@ -0,0 +1,639 @@ +package handler + +import ( + "fmt" + "strconv" + "time" + + "geekai/core" + "geekai/core/types" + "geekai/service/jimeng" + "geekai/store/model" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// JimengHandler 即梦AI处理器 +type JimengHandler struct { + BaseHandler + jimengService *jimeng.Service +} + +// NewJimengHandler 创建即梦AI处理器 +func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *JimengHandler { + return &JimengHandler{ + BaseHandler: BaseHandler{App: app}, + jimengService: jimengService, + } +} + +// TextToImage 文生图 +func (h *JimengHandler) TextToImage(c *gin.Context) { + var req struct { + Prompt string `json:"prompt" binding:"required"` + Seed int64 `json:"seed"` + Scale float64 `json:"scale"` + Width int `json:"width"` + Height int `json:"height"` + UsePreLLM bool `json:"use_pre_llm"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 20 { // 文生图消耗20算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Scale == 0 { + req.Scale = 2.5 + } + if req.Width == 0 { + req.Width = 1328 + } + if req.Height == 0 { + req.Height = 1328 + } + if req.Seed == 0 { + req.Seed = -1 + } + + // 构建任务参数 + params := map[string]interface{}{ + "seed": req.Seed, + "scale": req.Scale, + "width": req.Width, + "height": req.Height, + "use_pre_llm": req.UsePreLLM, + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeTextToImage, + Prompt: req.Prompt, + Params: params, + ReqKey: model.ReqKeyTextToImage, + Power: 20, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng text to image task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 20, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦文生图", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// ImageToImagePortrait 图生图人像写真 +func (h *JimengHandler) ImageToImagePortrait(c *gin.Context) { + var req struct { + ImageInput string `json:"image_input" binding:"required"` + Prompt string `json:"prompt"` + Width int `json:"width"` + Height int `json:"height"` + Gpen float64 `json:"gpen"` + Skin float64 `json:"skin"` + SkinUnifi float64 `json:"skin_unifi"` + GenMode string `json:"gen_mode"` + Seed int64 `json:"seed"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误: "+err.Error()) + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 30 { // 图生图消耗30算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Width == 0 { + req.Width = 1328 + } + if req.Height == 0 { + req.Height = 1328 + } + if req.Gpen == 0 { + req.Gpen = 0.4 + } + if req.Skin == 0 { + req.Skin = 0.3 + } + if req.GenMode == "" { + if req.Prompt != "" { + req.GenMode = jimeng.GenModeCreative + } else { + req.GenMode = jimeng.GenModeReference + } + } + if req.Seed == 0 { + req.Seed = -1 + } + if req.Prompt == "" { + req.Prompt = "演唱会现场的合照,闪光灯拍摄" + } + + // 构建任务参数 + params := map[string]interface{}{ + "image_input": req.ImageInput, + "width": req.Width, + "height": req.Height, + "gpen": req.Gpen, + "skin": req.Skin, + "skin_unifi": req.SkinUnifi, + "gen_mode": req.GenMode, + "seed": req.Seed, + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeImageToImagePortrait, + Prompt: req.Prompt, + Params: params, + ReqKey: model.ReqKeyImageToImagePortrait, + Power: 30, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng image to image portrait task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 30, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦图生图", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// ImageEdit 图像编辑 +func (h *JimengHandler) ImageEdit(c *gin.Context) { + var req struct { + ImageUrls []string `json:"image_urls"` + BinaryDataBase64 []string `json:"binary_data_base64"` + Prompt string `json:"prompt" binding:"required"` + Seed int64 `json:"seed"` + Scale float64 `json:"scale"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误: "+err.Error()) + return + } + + if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 { + resp.ERROR(c, "请提供图片URL或Base64数据") + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 25 { // 图像编辑消耗25算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Scale == 0 { + req.Scale = 0.5 + } + if req.Seed == 0 { + req.Seed = -1 + } + + // 构建任务参数 + params := map[string]interface{}{ + "seed": req.Seed, + "scale": req.Scale, + } + if len(req.ImageUrls) > 0 { + params["image_urls"] = req.ImageUrls + } + if len(req.BinaryDataBase64) > 0 { + params["binary_data_base64"] = req.BinaryDataBase64 + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeImageEdit, + Prompt: req.Prompt, + Params: params, + ReqKey: model.ReqKeyImageEdit, + Power: 25, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng image edit task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 25, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦图像编辑", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// ImageEffects 图像特效 +func (h *JimengHandler) ImageEffects(c *gin.Context) { + var req struct { + ImageInput1 string `json:"image_input1" binding:"required"` + TemplateId string `json:"template_id" binding:"required"` + Width int `json:"width"` + Height int `json:"height"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误: "+err.Error()) + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 15 { // 图像特效消耗15算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Width == 0 { + req.Width = 1328 + } + if req.Height == 0 { + req.Height = 1328 + } + + // 构建任务参数 + params := map[string]interface{}{ + "image_input1": req.ImageInput1, + "template_id": req.TemplateId, + "width": req.Width, + "height": req.Height, + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeImageEffects, + Prompt: "", + Params: params, + ReqKey: model.ReqKeyImageEffects, + Power: 15, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng image effects task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 15, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦图像特效", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// TextToVideo 文生视频 +func (h *JimengHandler) TextToVideo(c *gin.Context) { + var req struct { + Prompt string `json:"prompt" binding:"required"` + Seed int64 `json:"seed"` + AspectRatio string `json:"aspect_ratio"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误: "+err.Error()) + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 100 { // 文生视频消耗100算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Seed == 0 { + req.Seed = -1 + } + if req.AspectRatio == "" { + req.AspectRatio = jimeng.AspectRatio16_9 + } + + // 构建任务参数 + params := map[string]interface{}{ + "seed": req.Seed, + "aspect_ratio": req.AspectRatio, + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeTextToVideo, + Prompt: req.Prompt, + Params: params, + ReqKey: model.ReqKeyTextToVideo, + Power: 100, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng text to video task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 100, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦文生视频", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// ImageToVideo 图生视频 +func (h *JimengHandler) ImageToVideo(c *gin.Context) { + var req struct { + ImageUrls []string `json:"image_urls"` + BinaryDataBase64 []string `json:"binary_data_base64"` + Prompt string `json:"prompt"` + Seed int64 `json:"seed"` + AspectRatio string `json:"aspect_ratio" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + resp.ERROR(c, "参数错误: "+err.Error()) + return + } + + if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 { + resp.ERROR(c, "请提供图片URL或Base64数据") + return + } + + // 获取当前用户 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + // 检查用户算力 + if user.Power < 120 { // 图生视频消耗120算力 + resp.ERROR(c, "算力不足") + return + } + + // 设置默认参数 + if req.Seed == 0 { + req.Seed = -1 + } + + // 构建任务参数 + params := map[string]interface{}{ + "seed": req.Seed, + "aspect_ratio": req.AspectRatio, + } + if len(req.ImageUrls) > 0 { + params["image_urls"] = req.ImageUrls + } + if len(req.BinaryDataBase64) > 0 { + params["binary_data_base64"] = req.BinaryDataBase64 + } + + // 创建任务 + taskReq := &jimeng.CreateTaskRequest{ + Type: model.JimengJobTypeImageToVideo, + Prompt: req.Prompt, + Params: params, + ReqKey: model.ReqKeyImageToVideo, + Power: 120, + } + + job, err := h.jimengService.CreateTask(user.Id, taskReq) + if err != nil { + logger.Errorf("create jimeng image to video task failed: %v", err) + resp.ERROR(c, "创建任务失败") + return + } + + // 扣除用户算力 + h.subUserPower(user.Id, 120, model.PowerLog{ + Type: types.PowerConsume, + Model: "即梦图生视频", + Remark: fmt.Sprintf("任务ID:%d", job.Id), + }) + + resp.SUCCESS(c, job) +} + +// Jobs 获取任务列表 +func (h *JimengHandler) Jobs(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + + jobs, total, err := h.jimengService.GetUserJobs(user.Id, page, pageSize) + if err != nil { + logger.Errorf("get user jimeng jobs failed: %v", err) + resp.ERROR(c, "获取任务列表失败") + return + } + + resp.SUCCESS(c, gin.H{ + "jobs": jobs, + "total": total, + "page": page, + "page_size": pageSize, + }) +} + +// PendingCount 获取未完成任务数量 +func (h *JimengHandler) PendingCount(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + count, err := h.jimengService.GetPendingTaskCount(user.Id) + if err != nil { + logger.Errorf("get pending task count failed: %v", err) + resp.ERROR(c, "获取待处理任务数量失败") + return + } + + resp.SUCCESS(c, gin.H{"count": count}) +} + +// Remove 删除任务 +func (h *JimengHandler) Remove(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + jobId := h.GetInt(c, "id", 0) + if jobId == 0 { + resp.ERROR(c, "参数错误") + return + } + + if err := h.jimengService.DeleteJob(uint(jobId), user.Id); err != nil { + logger.Errorf("delete jimeng job failed: %v", err) + resp.ERROR(c, "删除任务失败") + return + } + + resp.SUCCESS(c, gin.H{}) +} + +// Retry 重试任务 +func (h *JimengHandler) Retry(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + jobIdStr := c.Param("id") + jobId, err := strconv.ParseUint(jobIdStr, 10, 32) + if err != nil { + resp.ERROR(c, "参数错误") + return + } + + // 检查任务是否存在且属于当前用户 + job, err := h.jimengService.GetJob(uint(jobId)) + if err != nil { + resp.ERROR(c, "任务不存在") + return + } + + if job.UserId != user.Id { + resp.ERROR(c, "无权限操作") + return + } + + // 只有失败的任务才能重试 + if job.Status != model.JimengJobStatusFailed { + resp.ERROR(c, "只有失败的任务才能重试") + return + } + + // 重置任务状态 + if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JimengJobStatusPending, ""); err != nil { + logger.Errorf("reset job status failed: %v", err) + resp.ERROR(c, "重置任务状态失败") + return + } + + // 重新推送到队列 + task := map[string]interface{}{ + "job_id": jobId, + "type": job.Type, + } + if err := h.jimengService.PushTaskToQueue(task); err != nil { + logger.Errorf("push retry task to queue failed: %v", err) + resp.ERROR(c, "推送重试任务失败") + return + } + + resp.SUCCESS(c, gin.H{"message": "重试任务已提交"}) +} + +// subUserPower 扣除用户算力 +func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.PowerLog) { + session := h.DB.Session(&gorm.Session{}) + + // 更新用户算力 + if err := session.Model(&model.User{}).Where("id = ?", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error; err != nil { + logger.Errorf("update user power failed: %v", err) + return + } + + // 记录算力消费日志 + powerLog.UserId = userId + powerLog.Amount = power + powerLog.Mark = types.PowerSub + powerLog.CreatedAt = time.Now() + if err := session.Create(&powerLog).Error; err != nil { + logger.Errorf("create power log failed: %v", err) + return + } + + session.Commit() +} diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index b1f9fe96..fc522a7c 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -160,7 +160,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { UserId: userId, ImgArr: data.ImgArr, Mode: h.App.SysConfig.MjMode, - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, } job := model.MidJourneyJob{ Type: data.TaskType, diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go index 31fecc9a..100099b4 100644 --- a/api/handler/prompt_handler.go +++ b/api/handler/prompt_handler.go @@ -48,7 +48,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return @@ -79,7 +79,7 @@ func (h *PromptHandler) Image(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return @@ -108,7 +108,7 @@ func (h *PromptHandler) Video(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return @@ -158,9 +158,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) { } func (h *PromptHandler) getPromptModel() string { - if h.App.SysConfig.TranslateModelId > 0 { + if h.App.SysConfig.AssistantModelId > 0 { var chatModel model.ChatModel - h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel) + h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel) return chatModel.Value } return "gpt-4o" diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index c8358d08..f2eaf974 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -131,7 +131,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { HdSteps: data.HdSteps, }, UserId: userId, - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, } job := model.SdJob{ diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index a3aff209..6543a8c2 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -85,7 +85,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { Type: types.VideoLuma, Prompt: data.Prompt, Params: params, - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, } // 插入数据库 job := model.VideoJob{ @@ -181,7 +181,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) { Type: types.VideoKeLing, Prompt: data.Prompt, Params: params, - TranslateModelId: h.App.SysConfig.TranslateModelId, + TranslateModelId: h.App.SysConfig.AssistantModelId, Channel: data.Channel, } // 插入数据库 diff --git a/api/main.go b/api/main.go index ba53f64d..f92a1d74 100644 --- a/api/main.go +++ b/api/main.go @@ -17,6 +17,7 @@ import ( logger2 "geekai/logger" "geekai/service" "geekai/service/dalle" + "geekai/service/jimeng" "geekai/service/mj" "geekai/service/oss" "geekai/service/payment" @@ -140,6 +141,7 @@ func main() { fx.Provide(handler.NewProductHandler), fx.Provide(handler.NewConfigHandler), fx.Provide(handler.NewPowerLogHandler), + fx.Provide(handler.NewJimengHandler), fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewAdminHandler), @@ -153,6 +155,9 @@ func main() { fx.Provide(admin.NewOrderHandler), fx.Provide(admin.NewChatHandler), fx.Provide(admin.NewPowerLogHandler), + fx.Provide(func(app *core.AppServer, service *jimeng.Service) *admin.AdminJimengHandler { + return admin.NewAdminJimengHandler(app, service) + }), // 创建服务 fx.Provide(sms.NewSendServiceManager), @@ -203,6 +208,17 @@ func main() { s.SyncTaskProgress() s.DownloadFiles() }), + + // 即梦AI 服务 + fx.Provide(func(config *types.AppConfig) *jimeng.Client { + return jimeng.NewClient(config.ApiConfig.JimengConfig.AccessKey, config.ApiConfig.JimengConfig.SecretKey) + }), + fx.Provide(jimeng.NewService), + fx.Provide(jimeng.NewConsumer), + fx.Invoke(func(consumer *jimeng.Consumer) { + consumer.Start() + go consumer.MonitorQueue() + }), fx.Provide(service.NewUserService), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), @@ -496,6 +512,29 @@ func main() { group.GET("remove", h.Remove) group.GET("publish", h.Publish) }), + + // 即梦AI 路由 + fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) { + group := s.Engine.Group("/api/jimeng") + group.POST("text-to-image", h.TextToImage) + group.POST("image-to-image-portrait", h.ImageToImagePortrait) + group.POST("image-edit", h.ImageEdit) + group.POST("image-effects", h.ImageEffects) + group.POST("text-to-video", h.TextToVideo) + group.POST("image-to-video", h.ImageToVideo) + group.GET("jobs", h.Jobs) + group.GET("pending-count", h.PendingCount) + group.GET("remove", h.Remove) + group.POST("retry/:id", h.Retry) + }), + fx.Invoke(func(s *core.AppServer, h *admin.AdminJimengHandler) { + group := s.Engine.Group("/api/admin/jimeng") + group.GET("jobs", h.Jobs) + group.GET("job/:id", h.JobDetail) + group.DELETE("job/:id", h.Remove) + group.POST("batch-remove", h.BatchRemove) + group.GET("stats", h.Stats) + }), fx.Provide(admin.NewChatAppTypeHandler), fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) { group := s.Engine.Group("/api/admin/app/type") diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 29cf4491..d4f4ea33 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -49,7 +49,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien // PushTask push a new mj task in to task queue func (s *Service) PushTask(task types.DallTask) { logger.Infof("add a new DALL-E task to the task list: %+v", task) - s.taskQueue.RPush(task) + if err := s.taskQueue.RPush(task); err != nil { + logger.Errorf("push dall-e task to queue failed: %v", err) + } } func (s *Service) Run() { diff --git a/api/service/jimeng/client.go b/api/service/jimeng/client.go new file mode 100644 index 00000000..291a0aae --- /dev/null +++ b/api/service/jimeng/client.go @@ -0,0 +1,332 @@ +package jimeng + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "geekai/logger" +) + +var clientLogger = logger.GetLogger() + +// Client 即梦API客户端 +type Client struct { + accessKey string + secretKey string + region string + service string + baseURL string + httpClient *http.Client +} + +// NewClient 创建即梦API客户端 +func NewClient(accessKey, secretKey string) *Client { + return &Client{ + accessKey: accessKey, + secretKey: secretKey, + region: "cn-north-1", + service: "cv", + baseURL: "https://visual.volcengineapi.com", + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// SubmitTask 提交任务 +func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) { + // 构建请求URL + queryParams := map[string]string{ + "Action": "CVSync2AsyncSubmitTask", + "Version": "2022-08-31", + } + + reqURL := c.buildURL(queryParams) + + // 序列化请求体 + reqBody, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request body failed: %w", err) + } + + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("create http request failed: %w", err) + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + + // 签名请求 + if err := c.signRequest(httpReq, reqBody); err != nil { + return nil, fmt.Errorf("sign request failed: %w", err) + } + + // 发送请求 + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("send http request failed: %w", err) + } + defer resp.Body.Close() + + // 读取响应 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body failed: %w", err) + } + + clientLogger.Infof("Jimeng SubmitTask Response: %s", string(respBody)) + + // 解析响应 + var result SubmitTaskResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("unmarshal response failed: %w", err) + } + + return &result, nil +} + +// QueryTask 查询任务 +func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) { + // 构建请求URL + queryParams := map[string]string{ + "Action": "CVSync2AsyncGetResult", + "Version": "2022-08-31", + } + + reqURL := c.buildURL(queryParams) + + // 序列化请求体 + reqBody, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request body failed: %w", err) + } + + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("create http request failed: %w", err) + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + + // 签名请求 + if err := c.signRequest(httpReq, reqBody); err != nil { + return nil, fmt.Errorf("sign request failed: %w", err) + } + + // 发送请求 + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("send http request failed: %w", err) + } + defer resp.Body.Close() + + // 读取响应 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body failed: %w", err) + } + + clientLogger.Infof("Jimeng QueryTask Response: %s", string(respBody)) + + // 解析响应 + var result QueryTaskResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("unmarshal response failed: %w", err) + } + + return &result, nil +} + +// SubmitSyncTask 提交同步任务(仅用于文生图) +func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) { + // 构建请求URL + queryParams := map[string]string{ + "Action": "CVProcess", + "Version": "2022-08-31", + } + + reqURL := c.buildURL(queryParams) + + // 序列化请求体 + reqBody, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request body failed: %w", err) + } + + // 创建HTTP请求 + httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("create http request failed: %w", err) + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + + // 签名请求 + if err := c.signRequest(httpReq, reqBody); err != nil { + return nil, fmt.Errorf("sign request failed: %w", err) + } + + // 发送请求 + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("send http request failed: %w", err) + } + defer resp.Body.Close() + + // 读取响应 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body failed: %w", err) + } + + clientLogger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody)) + + // 解析响应 + var result QueryTaskResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("unmarshal response failed: %w", err) + } + + return &result, nil +} + +// buildURL 构建请求URL +func (c *Client) buildURL(queryParams map[string]string) string { + u, _ := url.Parse(c.baseURL) + q := u.Query() + for k, v := range queryParams { + q.Set(k, v) + } + u.RawQuery = q.Encode() + return u.String() +} + +// signRequest 签名请求 +func (c *Client) signRequest(req *http.Request, body []byte) error { + now := time.Now().UTC() + + // 设置基本头部 + req.Header.Set("X-Date", now.Format("20060102T150405Z")) + req.Header.Set("Host", req.URL.Host) + + // 计算内容哈希 + contentHash := sha256.Sum256(body) + req.Header.Set("X-Content-Sha256", hex.EncodeToString(contentHash[:])) + + // 构建签名字符串 + canonicalRequest := c.buildCanonicalRequest(req) + credentialScope := fmt.Sprintf("%s/%s/%s/request", now.Format("20060102"), c.region, c.service) + stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", + now.Format("20060102T150405Z"), credentialScope, sha256Hash(canonicalRequest)) + + // 计算签名 + signature := c.calculateSignature(stringToSign, now) + + // 设置Authorization头部 + authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + c.accessKey, credentialScope, c.getSignedHeaders(req), signature) + req.Header.Set("Authorization", authorization) + + return nil +} + +// buildCanonicalRequest 构建规范请求 +func (c *Client) buildCanonicalRequest(req *http.Request) string { + // HTTP方法 + method := req.Method + + // 规范URI + uri := req.URL.Path + if uri == "" { + uri = "/" + } + + // 规范查询字符串 + query := req.URL.Query() + var queryParts []string + for k, v := range query { + for _, val := range v { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(val))) + } + } + sort.Strings(queryParts) + canonicalQuery := strings.Join(queryParts, "&") + + // 规范头部 + var headerParts []string + headers := make(map[string]string) + for k, v := range req.Header { + key := strings.ToLower(k) + if len(v) > 0 { + headers[key] = strings.TrimSpace(v[0]) + } + } + + var headerKeys []string + for k := range headers { + headerKeys = append(headerKeys, k) + } + sort.Strings(headerKeys) + + for _, k := range headerKeys { + headerParts = append(headerParts, fmt.Sprintf("%s:%s", k, headers[k])) + } + canonicalHeaders := strings.Join(headerParts, "\n") + "\n" + + // 签名头部 + signedHeaders := c.getSignedHeaders(req) + + // 载荷哈希 + payloadHash := req.Header.Get("X-Content-Sha256") + + return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + method, uri, canonicalQuery, canonicalHeaders, signedHeaders, payloadHash) +} + +// getSignedHeaders 获取签名头部 +func (c *Client) getSignedHeaders(req *http.Request) string { + var headers []string + for k := range req.Header { + headers = append(headers, strings.ToLower(k)) + } + sort.Strings(headers) + return strings.Join(headers, ";") +} + +// calculateSignature 计算签名 +func (c *Client) calculateSignature(stringToSign string, t time.Time) string { + kDate := hmacSha256([]byte("HMAC-SHA256"+c.secretKey), []byte(t.Format("20060102"))) + kRegion := hmacSha256(kDate, []byte(c.region)) + kService := hmacSha256(kRegion, []byte(c.service)) + kSigning := hmacSha256(kService, []byte("request")) + signature := hmacSha256(kSigning, []byte(stringToSign)) + return hex.EncodeToString(signature) +} + +// hmacSha256 计算HMAC-SHA256 +func hmacSha256(key []byte, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} + +// sha256Hash 计算SHA256哈希 +func sha256Hash(data string) string { + hash := sha256.Sum256([]byte(data)) + return hex.EncodeToString(hash[:]) +} diff --git a/api/service/jimeng/consumer.go b/api/service/jimeng/consumer.go new file mode 100644 index 00000000..00f20061 --- /dev/null +++ b/api/service/jimeng/consumer.go @@ -0,0 +1,177 @@ +package jimeng + +import ( + "context" + "time" + + "geekai/logger" + "geekai/store/model" +) + +var jimengLogger = logger.GetLogger() + +// Consumer 即梦任务消费者 +type Consumer struct { + service *Service + ctx context.Context + cancel context.CancelFunc +} + +// NewConsumer 创建即梦任务消费者 +func NewConsumer(service *Service) *Consumer { + ctx, cancel := context.WithCancel(context.Background()) + return &Consumer{ + service: service, + ctx: ctx, + cancel: cancel, + } +} + +// Start 启动消费者 +func (c *Consumer) Start() { + jimengLogger.Info("Starting Jimeng task consumer...") + go c.consume() +} + +// Stop 停止消费者 +func (c *Consumer) Stop() { + jimengLogger.Info("Stopping Jimeng task consumer...") + c.cancel() +} + +// consume 消费任务 +func (c *Consumer) consume() { + for { + select { + case <-c.ctx.Done(): + jimengLogger.Info("Jimeng task consumer stopped") + return + default: + c.processTask() + } + } +} + +// processTask 处理任务 +func (c *Consumer) processTask() { + // 从队列中获取任务 + var task map[string]interface{} + if err := c.service.taskQueue.LPop(&task); err != nil { + // 队列为空,等待1秒后重试 + time.Sleep(time.Second) + return + } + + // 解析任务 + jobIdFloat, ok := task["job_id"].(float64) + if !ok { + jimengLogger.Errorf("invalid job_id in task: %v", task) + return + } + jobId := uint(jobIdFloat) + + taskType, ok := task["type"].(string) + if !ok { + jimengLogger.Errorf("invalid task type in task: %v", task) + return + } + + jimengLogger.Infof("Processing Jimeng task: job_id=%d, type=%s", jobId, taskType) + + // 处理任务 + if err := c.service.ProcessTask(jobId); err != nil { + jimengLogger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err) + + // 任务失败,直接标记为失败状态,不进行重试 + c.service.UpdateJobStatus(jobId, model.JimengJobStatusFailed, err.Error()) + } else { + jimengLogger.Infof("Jimeng task processed successfully: job_id=%d", jobId) + } +} + +// TaskQueueStatus 任务队列状态 +type TaskQueueStatus struct { + QueueLength int `json:"queue_length"` + ActiveTasks int `json:"active_tasks"` +} + +// GetQueueStatus 获取队列状态 +func (c *Consumer) GetQueueStatus() (*TaskQueueStatus, error) { + // 获取队列长度 + length, err := c.service.taskQueue.Size() + if err != nil { + return nil, err + } + + // 获取活跃任务数(正在处理的任务) + activeTasks, err := c.service.GetPendingTaskCount(0) // 0表示所有用户 + if err != nil { + activeTasks = 0 + } + + return &TaskQueueStatus{ + QueueLength: int(length), + ActiveTasks: int(activeTasks), + }, nil +} + +// MonitorQueue 监控队列状态 +func (c *Consumer) MonitorQueue() { + ticker := time.NewTicker(30 * time.Second) // 每30秒监控一次 + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + status, err := c.GetQueueStatus() + if err != nil { + jimengLogger.Errorf("get queue status failed: %v", err) + continue + } + + if status.QueueLength > 0 || status.ActiveTasks > 0 { + jimengLogger.Infof("Jimeng queue status: queue_length=%d, active_tasks=%d", + status.QueueLength, status.ActiveTasks) + } + } + } +} + +// PushTaskToQueue 推送任务到队列(用于手动重试) +func (c *Consumer) PushTaskToQueue(task map[string]interface{}) error { + return c.service.taskQueue.RPush(task) +} + +// GetTaskStats 获取任务统计信息 +func (c *Consumer) GetTaskStats() (map[string]interface{}, error) { + type StatResult struct { + Status string `json:"status"` + Count int64 `json:"count"` + } + + var stats []StatResult + err := c.service.db.Model(&model.JimengJob{}). + Select("status, COUNT(*) as count"). + Group("status"). + Find(&stats).Error + if err != nil { + return nil, err + } + + result := map[string]interface{}{ + "total": int64(0), + "completed": int64(0), + "processing": int64(0), + "failed": int64(0), + "pending": int64(0), + } + + for _, stat := range stats { + result["total"] = result["total"].(int64) + stat.Count + result[stat.Status] = stat.Count + } + + return result, nil +} \ No newline at end of file diff --git a/api/service/jimeng/service.go b/api/service/jimeng/service.go new file mode 100644 index 00000000..b011f55f --- /dev/null +++ b/api/service/jimeng/service.go @@ -0,0 +1,633 @@ +package jimeng + +import ( + "encoding/json" + "fmt" + "strconv" + "time" + + "gorm.io/gorm" + + "geekai/logger" + "geekai/store" + "geekai/store/model" + "geekai/utils" + + "github.com/go-redis/redis/v8" +) + +var serviceLogger = logger.GetLogger() + +// Service 即梦服务 +type Service struct { + db *gorm.DB + redis *redis.Client + taskQueue *store.RedisQueue + client *Client +} + +// NewService 创建即梦服务 +func NewService(db *gorm.DB, redisCli *redis.Client, client *Client) *Service { + taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli) + return &Service{ + db: db, + redis: redisCli, + taskQueue: taskQueue, + client: client, + } +} + +// CreateTask 创建任务 +func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) { + // 生成任务ID + taskId := utils.RandString(20) + + // 序列化任务参数 + paramsJson, err := json.Marshal(req.Params) + if err != nil { + return nil, fmt.Errorf("marshal task params failed: %w", err) + } + + // 创建任务记录 + job := &model.JimengJob{ + UserId: userId, + TaskId: taskId, + Type: req.Type, + ReqKey: req.ReqKey, + Prompt: req.Prompt, + TaskParams: string(paramsJson), + Status: model.JimengJobStatusPending, + Power: req.Power, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + // 保存到数据库 + if err := s.db.Create(job).Error; err != nil { + return nil, fmt.Errorf("create jimeng job failed: %w", err) + } + + // 推送到任务队列 + task := map[string]any{ + "job_id": job.Id, + "type": job.Type, + } + if err := s.taskQueue.RPush(task); err != nil { + return nil, fmt.Errorf("push jimeng task to queue failed: %w", err) + } + + return job, nil +} + +// ProcessTask 处理任务 +func (s *Service) ProcessTask(jobId uint) error { + // 获取任务记录 + var job model.JimengJob + if err := s.db.First(&job, jobId).Error; err != nil { + return fmt.Errorf("get jimeng job failed: %w", err) + } + + // 更新任务状态为处理中 + if err := s.UpdateJobStatus(job.Id, model.JimengJobStatusProcessing, ""); err != nil { + return fmt.Errorf("update job status failed: %w", err) + } + + // 根据任务类型处理 + switch job.Type { + case model.JimengJobTypeTextToImage: + return s.processTextToImage(&job) + case model.JimengJobTypeImageToImagePortrait: + return s.processImageToImagePortrait(&job) + case model.JimengJobTypeImageEdit: + return s.processImageEdit(&job) + case model.JimengJobTypeImageEffects: + return s.processImageEffects(&job) + case model.JimengJobTypeTextToVideo: + return s.processTextToVideo(&job) + case model.JimengJobTypeImageToVideo: + return s.processImageToVideo(&job) + default: + return fmt.Errorf("unsupported task type: %s", job.Type) + } +} + +// processTextToImage 处理文生图任务 +func (s *Service) processTextToImage(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + Prompt: job.Prompt, + } + + // 设置参数 + if seed, ok := params["seed"]; ok { + if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { + req.Seed = seedVal + } + } + if scale, ok := params["scale"]; ok { + if scaleVal, ok := scale.(float64); ok { + req.Scale = scaleVal + } + } + if width, ok := params["width"]; ok { + if widthVal, ok := width.(float64); ok { + req.Width = int(widthVal) + } + } + if height, ok := params["height"]; ok { + if heightVal, ok := height.(float64); ok { + req.Height = int(heightVal) + } + } + if usePreLlm, ok := params["use_pre_llm"]; ok { + if usePreLlmVal, ok := usePreLlm.(bool); ok { + req.UsePreLLM = usePreLlmVal + } + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// processImageToImagePortrait 处理图生图人像写真任务 +func (s *Service) processImageToImagePortrait(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + Prompt: job.Prompt, + } + + // 设置图像输入 + if imageInput, ok := params["image_input"].(string); ok { + req.ImageInput = imageInput + } + + // 设置其他参数 + if gpen, ok := params["gpen"]; ok { + if gpenVal, ok := gpen.(float64); ok { + req.Gpen = gpenVal + } + } + if skin, ok := params["skin"]; ok { + if skinVal, ok := skin.(float64); ok { + req.Skin = skinVal + } + } + if skinUnifi, ok := params["skin_unifi"]; ok { + if skinUnifiVal, ok := skinUnifi.(float64); ok { + req.SkinUnifi = skinUnifiVal + } + } + if genMode, ok := params["gen_mode"].(string); ok { + req.GenMode = genMode + } + if width, ok := params["width"]; ok { + if widthVal, ok := width.(float64); ok { + req.Width = int(widthVal) + } + } + if height, ok := params["height"]; ok { + if heightVal, ok := height.(float64); ok { + req.Height = int(heightVal) + } + } + if seed, ok := params["seed"]; ok { + if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { + req.Seed = seedVal + } + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// processImageEdit 处理图像编辑任务 +func (s *Service) processImageEdit(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + Prompt: job.Prompt, + } + + // 设置图像输入 + if imageUrls, ok := params["image_urls"].([]any); ok { + for _, url := range imageUrls { + if urlStr, ok := url.(string); ok { + req.ImageUrls = append(req.ImageUrls, urlStr) + } + } + } + if binaryData, ok := params["binary_data_base64"].([]any); ok { + for _, data := range binaryData { + if dataStr, ok := data.(string); ok { + req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr) + } + } + } + + // 设置其他参数 + if seed, ok := params["seed"]; ok { + if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { + req.Seed = seedVal + } + } + if scale, ok := params["scale"]; ok { + if scaleVal, ok := scale.(float64); ok { + req.Scale = scaleVal + } + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// processImageEffects 处理图像特效任务 +func (s *Service) processImageEffects(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + } + + // 设置图像输入 + if imageInput1, ok := params["image_input1"].(string); ok { + req.ImageInput1 = imageInput1 + } + if templateId, ok := params["template_id"].(string); ok { + req.TemplateId = templateId + } + if width, ok := params["width"]; ok { + if widthVal, ok := width.(float64); ok { + req.Width = int(widthVal) + } + } + if height, ok := params["height"]; ok { + if heightVal, ok := height.(float64); ok { + req.Height = int(heightVal) + } + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// processTextToVideo 处理文生视频任务 +func (s *Service) processTextToVideo(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + Prompt: job.Prompt, + } + + // 设置参数 + if seed, ok := params["seed"]; ok { + if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { + req.Seed = seedVal + } + } + if aspectRatio, ok := params["aspect_ratio"].(string); ok { + req.AspectRatio = aspectRatio + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// processImageToVideo 处理图生视频任务 +func (s *Service) processImageToVideo(job *model.JimengJob) error { + // 解析任务参数 + var params map[string]any + if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err)) + } + + // 构建请求 + req := &SubmitTaskRequest{ + ReqKey: job.ReqKey, + Prompt: job.Prompt, + } + + // 设置图像输入 + if imageUrls, ok := params["image_urls"].([]any); ok { + for _, url := range imageUrls { + if urlStr, ok := url.(string); ok { + req.ImageUrls = append(req.ImageUrls, urlStr) + } + } + } + if binaryData, ok := params["binary_data_base64"].([]any); ok { + for _, data := range binaryData { + if dataStr, ok := data.(string); ok { + req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr) + } + } + } + + // 设置其他参数 + if seed, ok := params["seed"]; ok { + if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil { + req.Seed = seedVal + } + } + if aspectRatio, ok := params["aspect_ratio"].(string); ok { + req.AspectRatio = aspectRatio + } + + // 提交异步任务 + resp, err := s.client.SubmitTask(req) + if err != nil { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err)) + } + + if resp.Code != 10000 { + return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message)) + } + + // 更新任务ID和原始数据 + rawData, _ := json.Marshal(resp) + if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{ + "task_id": resp.Data.TaskId, + "raw_data": string(rawData), + "updated_at": time.Now(), + }).Error; err != nil { + serviceLogger.Errorf("update jimeng job task_id failed: %v", err) + } + + // 开始轮询任务状态 + return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey) +} + +// pollTaskStatus 轮询任务状态 +func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error { + maxRetries := 60 // 最大重试次数,60次 * 5秒 = 5分钟 + retryCount := 0 + + for retryCount < maxRetries { + time.Sleep(5 * time.Second) // 等待5秒 + + // 查询任务状态 + resp, err := s.client.QueryTask(&QueryTaskRequest{ + ReqKey: reqKey, + TaskId: taskId, + ReqJson: `{"return_url":true}`, + }) + + if err != nil { + serviceLogger.Errorf("query jimeng task status failed: %v", err) + retryCount++ + continue + } + + // 更新原始数据 + rawData, _ := json.Marshal(resp) + s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData)) + + if resp.Code != 10000 { + return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message)) + } + + switch resp.Data.Status { + case TaskStatusDone: + // 任务完成,更新结果 + updates := map[string]any{ + "status": model.JimengJobStatusCompleted, + "progress": 100, + "updated_at": time.Now(), + } + + // 设置结果URL + if len(resp.Data.ImageUrls) > 0 { + updates["img_url"] = resp.Data.ImageUrls[0] + } + if resp.Data.VideoUrl != "" { + updates["video_url"] = resp.Data.VideoUrl + } + + return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error + + case TaskStatusInQueue: + // 任务在队列中 + s.UpdateJobProgress(jobId, 10) + + case TaskStatusGenerating: + // 任务处理中 + s.UpdateJobProgress(jobId, 50) + + case TaskStatusNotFound, TaskStatusExpired: + // 任务未找到或已过期 + return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status)) + + default: + serviceLogger.Warnf("unknown task status: %s", resp.Data.Status) + } + + retryCount++ + } + + // 超时处理 + return s.handleTaskError(jobId, "task timeout") +} + +// UpdateJobStatus 更新任务状态 +func (s *Service) UpdateJobStatus(jobId uint, status, errMsg string) error { + updates := map[string]any{ + "status": status, + "updated_at": time.Now(), + } + if errMsg != "" { + updates["err_msg"] = errMsg + } + return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error +} + +// UpdateJobProgress 更新任务进度 +func (s *Service) UpdateJobProgress(jobId uint, progress int) error { + return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(map[string]any{ + "progress": progress, + "updated_at": time.Now(), + }).Error +} + +// handleTaskError 处理任务错误 +func (s *Service) handleTaskError(jobId uint, errMsg string) error { + serviceLogger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg) + return s.UpdateJobStatus(jobId, model.JimengJobStatusFailed, errMsg) +} + +// GetJob 获取任务 +func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) { + var job model.JimengJob + if err := s.db.First(&job, jobId).Error; err != nil { + return nil, err + } + return &job, nil +} + +// GetUserJobs 获取用户任务列表 +func (s *Service) GetUserJobs(userId uint, page, pageSize int) ([]*model.JimengJob, int64, error) { + var jobs []*model.JimengJob + var total int64 + + query := s.db.Model(&model.JimengJob{}).Where("user_id = ?", userId) + + // 统计总数 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // 分页查询 + offset := (page - 1) * pageSize + if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil { + return nil, 0, err + } + + return jobs, total, nil +} + +// GetPendingTaskCount 获取用户未完成任务数量 +func (s *Service) GetPendingTaskCount(userId uint) (int64, error) { + var count int64 + err := s.db.Model(&model.JimengJob{}).Where("user_id = ? AND status IN (?)", userId, + []string{model.JimengJobStatusPending, model.JimengJobStatusProcessing}).Count(&count).Error + return count, err +} + +// DeleteJob 删除任务 +func (s *Service) DeleteJob(jobId uint, userId uint) error { + return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error +} + +// PushTaskToQueue 推送任务到队列 +func (s *Service) PushTaskToQueue(task map[string]interface{}) error { + return s.taskQueue.RPush(task) +} diff --git a/api/service/jimeng/types.go b/api/service/jimeng/types.go new file mode 100644 index 00000000..029e0f91 --- /dev/null +++ b/api/service/jimeng/types.go @@ -0,0 +1,163 @@ +package jimeng + +import "time" + +// SubmitTaskRequest 提交任务请求 +type SubmitTaskRequest struct { + ReqKey string `json:"req_key"` + // 文生图参数 + Prompt string `json:"prompt,omitempty"` + Seed int64 `json:"seed,omitempty"` + Scale float64 `json:"scale,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + UsePreLLM bool `json:"use_pre_llm,omitempty"` + // 图生图参数 + ImageInput string `json:"image_input,omitempty"` + ImageUrls []string `json:"image_urls,omitempty"` + BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` + Gpen float64 `json:"gpen,omitempty"` + Skin float64 `json:"skin,omitempty"` + SkinUnifi float64 `json:"skin_unifi,omitempty"` + GenMode string `json:"gen_mode,omitempty"` + // 图像编辑参数 + // 图像特效参数 + ImageInput1 string `json:"image_input1,omitempty"` + TemplateId string `json:"template_id,omitempty"` + // 视频生成参数 + AspectRatio string `json:"aspect_ratio,omitempty"` +} + +// SubmitTaskResponse 提交任务响应 +type SubmitTaskResponse struct { + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Status int `json:"status"` + TimeElapsed string `json:"time_elapsed"` + Data struct { + TaskId string `json:"task_id"` + } `json:"data"` +} + +// QueryTaskRequest 查询任务请求 +type QueryTaskRequest struct { + ReqKey string `json:"req_key"` + TaskId string `json:"task_id"` + ReqJson string `json:"req_json,omitempty"` +} + +// QueryTaskResponse 查询任务响应 +type QueryTaskResponse struct { + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Status int `json:"status"` + TimeElapsed string `json:"time_elapsed"` + Data struct { + AlgorithmBaseResp struct { + StatusCode int `json:"status_code"` + StatusMessage string `json:"status_message"` + } `json:"algorithm_base_resp"` + BinaryDataBase64 []string `json:"binary_data_base64"` + ImageUrls []string `json:"image_urls"` + VideoUrl string `json:"video_url"` + RespData string `json:"resp_data"` + Status string `json:"status"` + LlmResult string `json:"llm_result"` + PeResult string `json:"pe_result"` + PredictTagsResult string `json:"predict_tags_result"` + RephraserResult string `json:"rephraser_result"` + VlmResult string `json:"vlm_result"` + InferCtx interface{} `json:"infer_ctx"` + } `json:"data"` +} + +// TaskStatus 任务状态 +const ( + TaskStatusInQueue = "in_queue" // 任务已提交 + TaskStatusGenerating = "generating" // 任务处理中 + TaskStatusDone = "done" // 处理完成 + TaskStatusNotFound = "not_found" // 任务未找到 + TaskStatusExpired = "expired" // 任务已过期 +) + +// CreateTaskRequest 创建任务请求 +type CreateTaskRequest struct { + Type string `json:"type"` + Prompt string `json:"prompt"` + Params map[string]interface{} `json:"params"` + ReqKey string `json:"req_key"` + ImageUrls []string `json:"image_urls,omitempty"` + Power int `json:"power,omitempty"` +} + +// TaskInfo 任务信息 +type TaskInfo struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + TaskId string `json:"task_id"` + Type string `json:"type"` + ReqKey string `json:"req_key"` + Prompt string `json:"prompt"` + TaskParams string `json:"task_params"` + ImgURL string `json:"img_url"` + VideoURL string `json:"video_url"` + Progress int `json:"progress"` + Status string `json:"status"` + ErrMsg string `json:"err_msg"` + Power int `json:"power"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// LogoInfo 水印信息 +type LogoInfo struct { + AddLogo bool `json:"add_logo"` + Position int `json:"position"` + Language int `json:"language"` + Opacity float64 `json:"opacity"` + LogoTextContent string `json:"logo_text_content"` +} + +// ReqJsonConfig 查询配置 +type ReqJsonConfig struct { + ReturnUrl bool `json:"return_url"` + LogoInfo *LogoInfo `json:"logo_info,omitempty"` +} + +// ImageEffectTemplate 图像特效模板 +const ( + TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格 + TemplateIdMyWorld = "my_world" // 像素世界风 + TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版 + TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风 + TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版 + TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风 + TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风 + TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版 + TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰 + TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣 + TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡 + TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办 + TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里 + TemplateIdGlassBall = "glass_ball" // 玻璃球 +) + +// AspectRatio 视频宽高比 +const ( + AspectRatio16_9 = "16:9" // 1280×720 + AspectRatio9_16 = "9:16" // 720×1280 + AspectRatio1_1 = "1:1" // 960×960 + AspectRatio4_3 = "4:3" // 960×720 + AspectRatio3_4 = "3:4" // 720×960 + AspectRatio21_9 = "21:9" // 1680×720 + AspectRatio9_21 = "9:21" // 720×1680 +) + +// GenMode 生成模式 +const ( + GenModeCreative = "creative" // 提示词模式 + GenModeReference = "reference" // 全参考模式 + GenModeReferenceChar = "reference_char" // 人物参考模式 +) \ No newline at end of file diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 532b82fd..9253d84b 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -212,7 +212,9 @@ func (s *Service) DownloadImages() { // PushTask push a new mj task in to task queue func (s *Service) PushTask(task types.MjTask) { logger.Debugf("add a new MidJourney task to the task list: %+v", task) - s.taskQueue.RPush(task) + if err := s.taskQueue.RPush(task); err != nil { + logger.Errorf("push mj task to queue failed: %v", err) + } } // SyncTaskProgress 异步拉取任务 diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 2047923e..f9e4437c 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -253,7 +253,9 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, err func (s *Service) PushTask(task types.SdTask) { logger.Debugf("add a new MidJourney task to the task list: %+v", task) - s.taskQueue.RPush(task) + if err := s.taskQueue.RPush(task); err != nil { + logger.Errorf("push sd task to queue failed: %v", err) + } } // CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 diff --git a/api/service/suno/service.go b/api/service/suno/service.go index 59e2aecc..2a5a457f 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien func (s *Service) PushTask(task types.SunoTask) { logger.Infof("add a new Suno task to the task list: %+v", task) - s.taskQueue.RPush(task) + if err := s.taskQueue.RPush(task); err != nil { + logger.Errorf("push suno task to queue failed: %v", err) + } } func (s *Service) Run() { diff --git a/api/service/video/video.go b/api/service/video/video.go index 42628c13..9b82b26a 100644 --- a/api/service/video/video.go +++ b/api/service/video/video.go @@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien func (s *Service) PushTask(task types.VideoTask) { logger.Infof("add a new Video task to the task list: %+v", task) - s.taskQueue.RPush(task) + if err := s.taskQueue.RPush(task); err != nil { + logger.Errorf("push video task to queue failed: %v", err) + } } func (s *Service) Run() { diff --git a/api/store/model/jimeng_job.go b/api/store/model/jimeng_job.go new file mode 100644 index 00000000..5a16b027 --- /dev/null +++ b/api/store/model/jimeng_job.go @@ -0,0 +1,58 @@ +package model + +import ( + "time" +) + +// JimengJob 即梦AI任务模型 +type JimengJob struct { + Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + UserId uint `gorm:"column:user_id;type:int;not null;index;comment:用户ID" json:"user_id"` + TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"` + Type string `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"` + ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"` + Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"` + TaskParams string `gorm:"column:task_params;type:text;comment:任务参数JSON" json:"task_params"` + ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"` + VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"` + RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"` + Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"` + Status string `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"` + ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"` + Power int `gorm:"column:power;type:int;default:0;comment:消耗算力" json:"power"` + CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"` +} + +// JimengJobStatus 即梦任务状态常量 +const ( + JimengJobStatusPending = "pending" + JimengJobStatusProcessing = "processing" + JimengJobStatusCompleted = "completed" + JimengJobStatusFailed = "failed" +) + +// JimengJobType 即梦任务类型常量 +const ( + JimengJobTypeTextToImage = "text_to_image" // 文生图 + JimengJobTypeImageToImagePortrait = "image_to_image_portrait" // 图生图人像写真 + JimengJobTypeImageEdit = "image_edit" // 图像编辑 + JimengJobTypeImageEffects = "image_effects" // 图像特效 + JimengJobTypeTextToVideo = "text_to_video" // 文生视频 + JimengJobTypeImageToVideo = "image_to_video" // 图生视频 +) + +// ReqKey 常量定义 +const ( + ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图 + ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真 + ReqKeyImageEdit = "seededit_v3.0" // 图像编辑 + ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效 + ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频 + ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频 +) + +// TableName 返回数据表名称 +func (JimengJob) TableName() string { + return "chatgpt_jimeng_jobs" +} diff --git a/api/store/redis_queue.go b/api/store/redis_queue.go index 3251eb57..71e6378b 100644 --- a/api/store/redis_queue.go +++ b/api/store/redis_queue.go @@ -10,6 +10,7 @@ package store import ( "context" "geekai/utils" + "github.com/go-redis/redis/v8" ) @@ -23,15 +24,15 @@ func NewRedisQueue(name string, client *redis.Client) *RedisQueue { return &RedisQueue{name: name, client: client, ctx: context.Background()} } -func (q *RedisQueue) RPush(value interface{}) { - q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)) +func (q *RedisQueue) RPush(value any) error { + return q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)).Err() } -func (q *RedisQueue) LPush(value interface{}) { - q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)) +func (q *RedisQueue) LPush(value any) error { + return q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)).Err() } -func (q *RedisQueue) LPop(value interface{}) error { +func (q *RedisQueue) LPop(value any) error { result, err := q.client.BLPop(q.ctx, 0, q.name).Result() if err != nil { return err @@ -39,10 +40,18 @@ func (q *RedisQueue) LPop(value interface{}) error { return utils.JsonDecode(result[1], value) } -func (q *RedisQueue) RPop(value interface{}) error { +func (q *RedisQueue) RPop(value any) error { result, err := q.client.BRPop(q.ctx, 0, q.name).Result() if err != nil { return err } return utils.JsonDecode(result[1], value) } + +func (q *RedisQueue) Size() (int64, error) { + return q.client.LLen(q.ctx, q.name).Result() +} + +func (q *RedisQueue) Clear() error { + return q.client.Del(q.ctx, q.name).Err() +} diff --git a/api/store/vo/jimeng_job.go b/api/store/vo/jimeng_job.go new file mode 100644 index 00000000..14b76817 --- /dev/null +++ b/api/store/vo/jimeng_job.go @@ -0,0 +1,21 @@ +package vo + +// JimengJob 即梦AI任务VO +type JimengJob struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + TaskId string `json:"task_id"` + Type string `json:"type"` + ReqKey string `json:"req_key"` + Prompt string `json:"prompt"` + TaskParams string `json:"task_params"` + ImgURL string `json:"img_url"` + VideoURL string `json:"video_url"` + RawData string `json:"raw_data"` + Progress int `json:"progress"` + Status string `json:"status"` + ErrMsg string `json:"err_msg"` + Power int `json:"power"` + CreatedAt int64 `json:"created_at"` // 时间戳 + UpdatedAt int64 `json:"updated_at"` // 时间戳 +} diff --git a/api/test/test.go b/api/test/test.go deleted file mode 100644 index 99fd702a..00000000 --- a/api/test/test.go +++ /dev/null @@ -1,55 +0,0 @@ -package main - -import ( - "crypto/rand" - "encoding/hex" - "fmt" - "sync" -) - -const ( - codeLength = 32 // 兑换码长度 -) - -var ( - codeMap = make(map[string]bool) - mapMutex = &sync.Mutex{} -) - -// GenerateUniqueCode 生成唯一兑换码 -func GenerateUniqueCode() (string, error) { - for { - code, err := generateCode() - if err != nil { - return "", err - } - - mapMutex.Lock() - if !codeMap[code] { - codeMap[code] = true - mapMutex.Unlock() - return code, nil - } - mapMutex.Unlock() - } -} - -// generateCode 生成兑换码 -func generateCode() (string, error) { - bytes := make([]byte, codeLength/2) // 因为 hex 编码会使长度翻倍 - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return hex.EncodeToString(bytes), nil -} - -func main() { - for i := 0; i < 10; i++ { - code, err := GenerateUniqueCode() - if err != nil { - fmt.Println("Error generating code:", err) - return - } - fmt.Println("Generated code:", code) - } -} diff --git a/web/src/assets/iconfont/iconfont.css b/web/src/assets/iconfont/iconfont.css index 834e6462..73489687 100644 --- a/web/src/assets/iconfont/iconfont.css +++ b/web/src/assets/iconfont/iconfont.css @@ -1,8 +1,8 @@ @font-face { font-family: "iconfont"; /* Project id 4125778 */ - src: url('iconfont.woff2?t=1752731646117') format('woff2'), - url('iconfont.woff?t=1752731646117') format('woff'), - url('iconfont.ttf?t=1752731646117') format('truetype'); + src: url('iconfont.woff2?t=1752831319382') format('woff2'), + url('iconfont.woff?t=1752831319382') format('woff'), + url('iconfont.ttf?t=1752831319382') format('truetype'); } .iconfont { @@ -13,6 +13,14 @@ -moz-osx-font-smoothing: grayscale; } +.icon-jimeng2:before { + content: "\eabc"; +} + +.icon-jimeng:before { + content: "\eabb"; +} + .icon-video:before { content: "\e63f"; } diff --git a/web/src/assets/iconfont/iconfont.js b/web/src/assets/iconfont/iconfont.js index 0e090b2d..ffd1ad96 100644 --- a/web/src/assets/iconfont/iconfont.js +++ b/web/src/assets/iconfont/iconfont.js @@ -1 +1 @@ -window._iconfont_svg_string_4125778='',(a=>{var l=(c=(c=document.getElementsByTagName("script"))[c.length-1]).getAttribute("data-injectcss"),c=c.getAttribute("data-disable-injectsvg");if(!c){var h,t,i,o,z,m=function(l,c){c.parentNode.insertBefore(l,c)};if(l&&!a.__iconfont__svg__cssinject__){a.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(l){console&&console.log(l)}}h=function(){var l,c=document.createElement("div");c.innerHTML=a._iconfont_svg_string_4125778,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(l=document.body).firstChild?m(c,l.firstChild):l.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(h,0):(t=function(){document.removeEventListener("DOMContentLoaded",t,!1),h()},document.addEventListener("DOMContentLoaded",t,!1)):document.attachEvent&&(i=h,o=a.document,z=!1,v(),o.onreadystatechange=function(){"complete"==o.readyState&&(o.onreadystatechange=null,p())})}function p(){z||(z=!0,i())}function v(){try{o.documentElement.doScroll("left")}catch(l){return void setTimeout(v,50)}p()}})(window); \ No newline at end of file +window._iconfont_svg_string_4125778='',(a=>{var l=(c=(c=document.getElementsByTagName("script"))[c.length-1]).getAttribute("data-injectcss"),c=c.getAttribute("data-disable-injectsvg");if(!c){var h,t,i,o,z,m=function(l,c){c.parentNode.insertBefore(l,c)};if(l&&!a.__iconfont__svg__cssinject__){a.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(l){console&&console.log(l)}}h=function(){var l,c=document.createElement("div");c.innerHTML=a._iconfont_svg_string_4125778,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(l=document.body).firstChild?m(c,l.firstChild):l.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(h,0):(t=function(){document.removeEventListener("DOMContentLoaded",t,!1),h()},document.addEventListener("DOMContentLoaded",t,!1)):document.attachEvent&&(i=h,o=a.document,z=!1,v(),o.onreadystatechange=function(){"complete"==o.readyState&&(o.onreadystatechange=null,p())})}function p(){z||(z=!0,i())}function v(){try{o.documentElement.doScroll("left")}catch(l){return void setTimeout(v,50)}p()}})(window); \ No newline at end of file diff --git a/web/src/assets/iconfont/iconfont.json b/web/src/assets/iconfont/iconfont.json index 9f5dc208..4085727c 100644 --- a/web/src/assets/iconfont/iconfont.json +++ b/web/src/assets/iconfont/iconfont.json @@ -5,6 +5,20 @@ "css_prefix_text": "icon-", "description": "", "glyphs": [ + { + "icon_id": "42693930", + "name": "即梦AI-02", + "font_class": "jimeng2", + "unicode": "eabc", + "unicode_decimal": 60092 + }, + { + "icon_id": "42693927", + "name": "即梦AI-01", + "font_class": "jimeng", + "unicode": "eabb", + "unicode_decimal": 60091 + }, { "icon_id": "1283", "name": "视频", diff --git a/web/src/assets/iconfont/iconfont.ttf b/web/src/assets/iconfont/iconfont.ttf index 47385b32..e76c0a11 100644 Binary files a/web/src/assets/iconfont/iconfont.ttf and b/web/src/assets/iconfont/iconfont.ttf differ diff --git a/web/src/assets/iconfont/iconfont.woff b/web/src/assets/iconfont/iconfont.woff index 6029718b..5511edbd 100644 Binary files a/web/src/assets/iconfont/iconfont.woff and b/web/src/assets/iconfont/iconfont.woff differ diff --git a/web/src/assets/iconfont/iconfont.woff2 b/web/src/assets/iconfont/iconfont.woff2 index ceefa2f6..9ca0b337 100644 Binary files a/web/src/assets/iconfont/iconfont.woff2 and b/web/src/assets/iconfont/iconfont.woff2 differ diff --git a/web/src/components/ImageUpload.vue b/web/src/components/ImageUpload.vue new file mode 100644 index 00000000..d9d1f4d1 --- /dev/null +++ b/web/src/components/ImageUpload.vue @@ -0,0 +1,291 @@ + + + + + diff --git a/web/src/components/admin/AdminSidebar.vue b/web/src/components/admin/AdminSidebar.vue index 9375a8a5..8f852cdd 100644 --- a/web/src/components/admin/AdminSidebar.vue +++ b/web/src/components/admin/AdminSidebar.vue @@ -159,6 +159,11 @@ const items = [ index: '/admin/medias', title: '音视频记录', }, + { + icon: 'image', + index: '/admin/jimeng', + title: '即梦AI任务', + }, ], }, diff --git a/web/src/router.js b/web/src/router.js index 8e05f8ac..83093a66 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -109,6 +109,12 @@ const routes = [ meta: { title: '视频创作中心' }, component: () => import('@/views/Video.vue'), }, + { + name: 'jimeng', + path: '/jimeng', + meta: { title: '即梦AI' }, + component: () => import('@/views/Jimeng.vue'), + }, ], }, { @@ -252,6 +258,12 @@ const routes = [ meta: { title: '音视频管理' }, component: () => import('@/views/admin/records/Medias.vue'), }, + { + path: '/admin/jimeng', + name: 'admin-jimeng', + meta: { title: '即梦AI管理' }, + component: () => import('@/views/admin/JimengJobs.vue'), + }, { path: '/admin/powerLog', name: 'admin-power-log', diff --git a/web/src/store/jimeng.js b/web/src/store/jimeng.js new file mode 100644 index 00000000..33c0b2f9 --- /dev/null +++ b/web/src/store/jimeng.js @@ -0,0 +1,513 @@ +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import nodata from '@/assets/img/no-data.png' +import { checkSession } from '@/store/cache' +import { closeLoading, showLoading, showMessageError, showMessageOK } from '@/utils/dialog' +import { httpGet, httpPost } from '@/utils/http' +import { replaceImg, substr, dateFormat } from '@/utils/libs' +import { ElMessage, ElMessageBox } from 'element-plus' +import { defineStore } from 'pinia' +import { computed, reactive, ref } from 'vue' + +export const useJimengStore = defineStore('jimeng', () => { + // 当前激活的功能分类和具体功能 + const activeCategory = ref('image_generation') + const activeFunction = ref('text_to_image') + const useImageInput = ref(false) + + // 共同状态 + const loading = ref(false) + const submitting = ref(false) + const list = ref([]) + const noData = ref(true) + const page = ref(1) + const pageSize = ref(20) + const total = ref(0) + const taskPulling = ref(false) + const pullHandler = ref(null) + const taskFilter = ref('all') + const currentList = ref([]) + + // 用户信息 + const isLogin = ref(false) + const userPower = ref(100) + + // 视频预览 + const showDialog = ref(false) + const currentVideoUrl = ref('') + + // 功能分类配置 + const categories = [ + { key: 'image_generation', name: '图片生成' }, + { key: 'image_editing', name: 'AI修图' }, + { key: 'image_effects', name: '图像特效' }, + { key: 'video_generation', name: '视频生成' }, + ] + + // 功能配置 + const functions = [ + { key: 'text_to_image', name: '文生图', category: 'image_generation', needsPrompt: true, needsImage: false, power: 20 }, + { key: 'image_to_image_portrait', name: '图生图', category: 'image_generation', needsPrompt: true, needsImage: true, power: 30 }, + { key: 'image_edit', name: '图像编辑', category: 'image_editing', needsPrompt: true, needsImage: true, multiple: true, power: 25 }, + { key: 'image_effects', name: '图像特效', category: 'image_effects', needsPrompt: false, needsImage: true, power: 15 }, + { key: 'text_to_video', name: '文生视频', category: 'video_generation', needsPrompt: true, needsImage: false, power: 100 }, + { key: 'image_to_video', name: '图生视频', category: 'video_generation', needsPrompt: true, needsImage: true, multiple: true, power: 120 }, + ] + + // 各功能的参数 + const textToImageParams = reactive({ + prompt: '', + size: '1328x1328', + scale: 2.5, + seed: -1, + use_pre_llm: false, + }) + + const imageToImageParams = reactive({ + image_input: '', + prompt: '演唱会现场的合照,闪光灯拍摄', + size: '1328x1328', + gpen: 0.4, + skin: 0.3, + skin_unifi: 0, + gen_mode: 'creative', + seed: -1, + }) + + const imageEditParams = reactive({ + image_urls: [], + prompt: '', + scale: 0.5, + seed: -1, + }) + + const imageEffectsParams = reactive({ + image_input1: '', + template_id: '', + size: '1328x1328', + }) + + const textToVideoParams = reactive({ + prompt: '', + aspect_ratio: '16:9', + seed: -1, + }) + + const imageToVideoParams = reactive({ + image_urls: [], + prompt: '', + aspect_ratio: '16:9', + seed: -1, + }) + + // 计算属性 + const currentFunction = computed(() => { + return functions.find(f => f.key === activeFunction.value) || functions[0] + }) + + const currentFunctions = computed(() => { + return functions.filter(f => f.category === activeCategory.value) + }) + + const needsPrompt = computed(() => currentFunction.value.needsPrompt) + const needsImage = computed(() => currentFunction.value.needsImage) + const needsMultipleImages = computed(() => currentFunction.value.multiple) + const currentPowerCost = computed(() => currentFunction.value.power) + + // 初始化方法 + const init = async () => { + try { + const user = await checkSession() + isLogin.value = true + userPower.value = user.power + + // 获取任务列表 + await fetchData(1) + + // 检查是否需要开始轮询 + const pendingCount = await getPendingCount() + if (pendingCount > 0) { + startPolling() + } + } catch (error) { + console.error('初始化失败:', error) + } + } + + // 切换功能分类 + const switchCategory = (category) => { + activeCategory.value = category + const categoryFunctions = functions.filter(f => f.category === category) + if (categoryFunctions.length > 0) { + if (category === 'image_generation') { + activeFunction.value = useImageInput.value ? 'image_to_image_portrait' : 'text_to_image' + } else if (category === 'video_generation') { + activeFunction.value = useImageInput.value ? 'image_to_video' : 'text_to_video' + } else { + activeFunction.value = categoryFunctions[0].key + } + } + } + + // 切换输入模式 + const switchInputMode = () => { + if (activeCategory.value === 'image_generation') { + activeFunction.value = useImageInput.value ? 'image_to_image_portrait' : 'text_to_image' + } else if (activeCategory.value === 'video_generation') { + activeFunction.value = useImageInput.value ? 'image_to_video' : 'text_to_video' + } + } + + // 切换功能 + const switchFunction = (functionKey) => { + activeFunction.value = functionKey + } + + // 获取当前算力消耗 + const getCurrentPowerCost = () => { + return currentFunction.value.power + } + + // 获取功能名称 + const getFunctionName = (type) => { + const func = functions.find(f => f.key === type) + return func ? func.name : type + } + + // 获取任务状态文本 + const getTaskStatusText = (status) => { + const statusMap = { + 'pending': '等待中', + 'processing': '处理中', + 'completed': '已完成', + 'failed': '失败' + } + return statusMap[status] || status + } + + // 获取状态类型 + const getStatusType = (status) => { + const typeMap = { + 'pending': 'info', + 'processing': 'warning', + 'completed': 'success', + 'failed': 'danger' + } + return typeMap[status] || 'info' + } + + // 切换任务筛选 + const switchTaskFilter = (filter) => { + taskFilter.value = filter + updateCurrentList() + } + + // 更新当前列表 + const updateCurrentList = () => { + if (taskFilter.value === 'all') { + currentList.value = list.value + } else if (taskFilter.value === 'image') { + currentList.value = list.value.filter(item => + ['text_to_image', 'image_to_image_portrait', 'image_edit', 'image_effects'].includes(item.type) + ) + } else if (taskFilter.value === 'video') { + currentList.value = list.value.filter(item => + ['text_to_video', 'image_to_video'].includes(item.type) + ) + } + } + + // 获取任务列表 + const fetchData = async (pageNum = 1) => { + try { + loading.value = true + page.value = pageNum + + const response = await httpGet('/api/jimeng/jobs', { + page: pageNum, + page_size: pageSize.value + }) + + if (response.data) { + list.value = response.data.jobs || [] + total.value = response.data.total || 0 + noData.value = list.value.length === 0 + updateCurrentList() + } + } catch (error) { + console.error('获取任务列表失败:', error) + showMessageError('获取任务列表失败') + } finally { + loading.value = false + } + } + + // 提交任务 + const submitTask = async () => { + if (!isLogin.value) { + showMessageError('请先登录') + return + } + + if (userPower.value < currentPowerCost.value) { + showMessageError('算力不足') + return + } + + try { + submitting.value = true + let apiUrl = '' + let requestData = {} + + switch (activeFunction.value) { + case 'text_to_image': + apiUrl = '/api/jimeng/text-to-image' + requestData = { + prompt: textToImageParams.prompt, + width: parseInt(textToImageParams.size.split('x')[0]), + height: parseInt(textToImageParams.size.split('x')[1]), + scale: textToImageParams.scale, + seed: textToImageParams.seed, + use_pre_llm: textToImageParams.use_pre_llm, + } + break + + case 'image_to_image_portrait': + apiUrl = '/api/jimeng/image-to-image-portrait' + requestData = { + image_input: imageToImageParams.image_input, + prompt: imageToImageParams.prompt, + width: parseInt(imageToImageParams.size.split('x')[0]), + height: parseInt(imageToImageParams.size.split('x')[1]), + gpen: imageToImageParams.gpen, + skin: imageToImageParams.skin, + skin_unifi: imageToImageParams.skin_unifi, + gen_mode: imageToImageParams.gen_mode, + seed: imageToImageParams.seed, + } + break + + case 'image_edit': + apiUrl = '/api/jimeng/image-edit' + requestData = { + image_urls: imageEditParams.image_urls, + prompt: imageEditParams.prompt, + scale: imageEditParams.scale, + seed: imageEditParams.seed, + } + break + + case 'image_effects': + apiUrl = '/api/jimeng/image-effects' + requestData = { + image_input1: imageEffectsParams.image_input1, + template_id: imageEffectsParams.template_id, + width: parseInt(imageEffectsParams.size.split('x')[0]), + height: parseInt(imageEffectsParams.size.split('x')[1]), + } + break + + case 'text_to_video': + apiUrl = '/api/jimeng/text-to-video' + requestData = { + prompt: textToVideoParams.prompt, + aspect_ratio: textToVideoParams.aspect_ratio, + seed: textToVideoParams.seed, + } + break + + case 'image_to_video': + apiUrl = '/api/jimeng/image-to-video' + requestData = { + image_urls: imageToVideoParams.image_urls, + prompt: imageToVideoParams.prompt, + aspect_ratio: imageToVideoParams.aspect_ratio, + seed: imageToVideoParams.seed, + } + break + } + + const response = await httpPost(apiUrl, requestData) + + if (response.data) { + showMessageOK('任务提交成功') + // 重新获取任务列表 + await fetchData(1) + // 开始轮询 + startPolling() + } + } catch (error) { + console.error('提交任务失败:', error) + showMessageError(error.message || '提交任务失败') + } finally { + submitting.value = false + } + } + + // 获取待处理任务数量 + const getPendingCount = async () => { + try { + const response = await httpGet('/api/jimeng/pending-count') + return response.data?.count || 0 + } catch (error) { + console.error('获取待处理任务数量失败:', error) + return 0 + } + } + + // 开始轮询 + const startPolling = () => { + if (taskPulling.value) return + + taskPulling.value = true + pullHandler.value = setInterval(async () => { + const pendingCount = await getPendingCount() + if (pendingCount > 0) { + await fetchData(page.value) + } else { + stopPolling() + } + }, 3000) + } + + // 停止轮询 + const stopPolling = () => { + if (pullHandler.value) { + clearInterval(pullHandler.value) + pullHandler.value = null + } + taskPulling.value = false + } + + // 重试任务 + const retryTask = async (taskId) => { + try { + const response = await httpPost(`/api/jimeng/retry/${taskId}`) + if (response.data) { + showMessageOK('重试任务已提交') + await fetchData(page.value) + startPolling() + } + } catch (error) { + console.error('重试任务失败:', error) + showMessageError(error.message || '重试任务失败') + } + } + + // 删除任务 + const removeJob = async (item) => { + try { + await ElMessageBox.confirm('确定要删除这个任务吗?', '提示', { + confirmButtonText: '确定', + cancelButtonText: '取消', + type: 'warning', + }) + + const response = await httpGet('/api/jimeng/remove', { id: item.id }) + if (response.data) { + showMessageOK('删除成功') + await fetchData(page.value) + } + } catch (error) { + if (error !== 'cancel') { + console.error('删除任务失败:', error) + showMessageError(error.message || '删除任务失败') + } + } + } + + // 播放视频 + const playVideo = (item) => { + currentVideoUrl.value = item.video_url + showDialog.value = true + } + + // 下载文件 + const downloadFile = (item) => { + const url = item.video_url || item.img_url + if (url) { + const link = document.createElement('a') + link.href = url + link.download = `jimeng_${item.id}.${item.video_url ? 'mp4' : 'jpg'}` + link.click() + } + } + + // 清理 + const cleanup = () => { + stopPolling() + } + + // 返回所有状态和方法 + return { + // 状态 + activeCategory, + activeFunction, + useImageInput, + loading, + submitting, + list, + noData, + page, + pageSize, + total, + taskFilter, + currentList, + isLogin, + userPower, + showDialog, + currentVideoUrl, + nodata, + + // 配置 + categories, + functions, + currentFunctions, + + // 参数 + textToImageParams, + imageToImageParams, + imageEditParams, + imageEffectsParams, + textToVideoParams, + imageToVideoParams, + + // 计算属性 + currentFunction, + needsPrompt, + needsImage, + needsMultipleImages, + currentPowerCost, + + // 方法 + init, + switchCategory, + switchFunction, + switchInputMode, + getCurrentPowerCost, + getFunctionName, + getTaskStatusText, + getStatusType, + switchTaskFilter, + updateCurrentList, + fetchData, + submitTask, + getPendingCount, + startPolling, + stopPolling, + retryTask, + removeJob, + playVideo, + downloadFile, + cleanup, + + // 工具函数 + substr, + replaceImg, + } +}) \ No newline at end of file diff --git a/web/src/utils/libs.js b/web/src/utils/libs.js index cdb8695a..1c151507 100644 --- a/web/src/utils/libs.js +++ b/web/src/utils/libs.js @@ -255,3 +255,8 @@ export function isChrome() { const userAgent = navigator.userAgent.toLowerCase() return /chrome/.test(userAgent) && !/edg/.test(userAgent) } + +// 格式化日期时间 +export function formatDateTime(timestamp, format = 'yyyy-MM-dd HH:mm:ss') { + return dateFormat(timestamp, format) +} diff --git a/web/src/views/Home.vue b/web/src/views/Home.vue index 35e575b2..0ad5812e 100644 --- a/web/src/views/Home.vue +++ b/web/src/views/Home.vue @@ -69,7 +69,7 @@ +
+ + 登录 + +
-
+
@@ -281,7 +286,9 @@ const logout = function () { httpGet('/api/user/logout') .then(() => { removeUserToken() - router.push('/login') + // 刷新组件 + routerViewKey.value += 1 + loginUser.value = {} }) .catch(() => { ElMessage.error('注销失败!') diff --git a/web/src/views/Index.vue b/web/src/views/Index.vue index 9023e06f..61354a0e 100644 --- a/web/src/views/Index.vue +++ b/web/src/views/Index.vue @@ -69,7 +69,7 @@ class="nav-item-box" @click="router.push(item.url)" > - +
{{ item.name }}
@@ -107,20 +107,6 @@ const githubURL = ref(import.meta.env.VITE_GITHUB_URL) const giteeURL = ref(import.meta.env.VITE_GITEE_URL) const navs = ref([]) -const iconMap = ref({ - '/chat': 'icon-chat', - '/mj': 'icon-mj', - '/sd': 'icon-sd', - '/dalle': 'icon-dalle', - '/images-wall': 'icon-image', - '/suno': 'icon-suno', - '/xmind': 'icon-xmind', - '/apps': 'icon-app', - '/member': 'icon-vip-user', - '/invite': 'icon-share', - '/luma': 'icon-luma', -}) - const displayedChars = ref([]) const initAnimation = ref('') let timer = null // 定时器句柄 diff --git a/web/src/views/Jimeng.vue b/web/src/views/Jimeng.vue new file mode 100644 index 00000000..76baacb4 --- /dev/null +++ b/web/src/views/Jimeng.vue @@ -0,0 +1,799 @@ + + + + + \ No newline at end of file diff --git a/web/src/views/Video.vue b/web/src/views/Video.vue index c8281602..bb01df83 100644 --- a/web/src/views/Video.vue +++ b/web/src/views/Video.vue @@ -115,12 +115,12 @@
-
+
循环参考图
-
+
提示词优化
diff --git a/web/src/views/admin/JimengJobs.vue b/web/src/views/admin/JimengJobs.vue new file mode 100644 index 00000000..cf87ceec --- /dev/null +++ b/web/src/views/admin/JimengJobs.vue @@ -0,0 +1,543 @@ + + + + + \ No newline at end of file diff --git a/web/src/views/admin/SysConfig.vue b/web/src/views/admin/SysConfig.vue index cc3b8103..ce28b400 100644 --- a/web/src/views/admin/SysConfig.vue +++ b/web/src/views/admin/SysConfig.vue @@ -169,10 +169,10 @@