增加即梦AI功能页面

This commit is contained in:
GeekMaster
2025-07-18 18:04:32 +08:00
parent 66776556d8
commit 76d32c78d8
40 changed files with 4511 additions and 118 deletions

View File

@@ -43,9 +43,16 @@ type SmtpConfig struct {
}
type ApiConfig struct {
ApiURL string
AppId string
Token string
ApiURL string
AppId string
Token string
JimengConfig JimengConfig // 即梦AI配置
}
// JimengConfig 即梦AI配置
type JimengConfig struct {
AccessKey string // 火山引擎AccessKey
SecretKey string // 火山引擎SecretKey
}
type AlipayConfig struct {
@@ -170,7 +177,7 @@ type SystemConfig struct {
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的模型 id
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位MB
}

View File

@@ -0,0 +1,177 @@
package admin
import (
"strconv"
"geekai/core"
"geekai/handler"
"geekai/service/jimeng"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
)
// AdminJimengHandler 管理后台即梦AI处理器
type AdminJimengHandler struct {
handler.BaseHandler
jimengService *jimeng.Service
}
// NewAdminJimengHandler 创建管理后台即梦AI处理器
func NewAdminJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *AdminJimengHandler {
return &AdminJimengHandler{
BaseHandler: handler.BaseHandler{App: app},
jimengService: jimengService,
}
}
// Jobs 获取任务列表
func (h *AdminJimengHandler) Jobs(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
userId := h.GetInt(c, "user_id", 0)
taskType := h.GetTrim(c, "type")
status := h.GetTrim(c, "status")
var tasks []model.JimengJob
var total int64
session := h.DB.Model(&model.JimengJob{})
// 构建查询条件
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if taskType != "" {
session = session.Where("type = ?", taskType)
}
if status != "" {
session = session.Where("status = ?", status)
}
// 获取总数
err := session.Count(&total).Error
if err != nil {
resp.ERROR(c, "获取任务数量失败")
return
}
// 获取数据
offset := (page - 1) * pageSize
err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error
if err != nil {
resp.ERROR(c, "获取任务列表失败")
return
}
resp.SUCCESS(c, gin.H{
"jobs": tasks,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// JobDetail 获取任务详情
func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
idStr := c.Param("id")
jobId, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "参数错误")
return
}
var job model.JimengJob
err = h.DB.Where("id = ?", jobId).First(&job).Error
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
resp.SUCCESS(c, job)
}
// Remove 删除任务
func (h *AdminJimengHandler) Remove(c *gin.Context) {
idStr := c.Param("id")
jobId, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
resp.ERROR(c, "参数错误")
return
}
err = h.DB.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
if err != nil {
resp.ERROR(c, "删除任务失败")
return
}
resp.SUCCESS(c, gin.H{})
}
// BatchRemove 批量删除任务
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
var req struct {
JobIds []uint `json:"job_ids" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误")
return
}
result := h.DB.Where("id IN ?", req.JobIds).Delete(&model.JimengJob{})
if result.Error != nil {
resp.ERROR(c, "批量删除失败")
return
}
resp.SUCCESS(c, gin.H{
"message": "批量删除成功",
"deleted_count": result.RowsAffected,
})
}
// Stats 获取统计信息
func (h *AdminJimengHandler) Stats(c *gin.Context) {
type StatResult struct {
Status string `json:"status"`
Count int64 `json:"count"`
}
var stats []StatResult
err := h.DB.Model(&model.JimengJob{}).
Select("status, COUNT(*) as count").
Group("status").
Find(&stats).Error
if err != nil {
resp.ERROR(c, "获取统计信息失败")
return
}
// 整理统计数据
result := gin.H{
"totalTasks": int64(0),
"completedTasks": int64(0),
"processingTasks": int64(0),
"failedTasks": int64(0),
"pendingTasks": int64(0),
}
for _, stat := range stats {
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
switch stat.Status {
case "completed":
result["completedTasks"] = stat.Count
case "processing":
result["processingTasks"] = stat.Count
case "failed":
result["failedTasks"] = stat.Count
case "pending":
result["pendingTasks"] = stat.Count
}
}
resp.SUCCESS(c, result)
}

View File

@@ -77,7 +77,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
Power: chatModel.Power,
}
job := model.DallJob{

View File

@@ -213,7 +213,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
Prompt: prompt,
ModelId: 0,
ModelName: "dall-e-3",
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
N: 1,
Quality: "standard",
Size: "1024x1024",
@@ -265,27 +265,27 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
// 从参数中获取搜索关键词
keyword, ok := params["keyword"].(string)
if !ok || keyword == "" {
resp.ERROR(c, "搜索关键词不能为空")
return
}
// 从参数中获取最大页数默认为1页
maxPages := 1
if pages, ok := params["max_pages"].(float64); ok {
maxPages = int(pages)
}
// 获取用户ID
userID, ok := params["user_id"].(float64)
if !ok {
resp.ERROR(c, "用户ID不能为空")
return
}
// 查询用户信息
var user model.User
res := h.DB.Where("id = ?", int(userID)).First(&user)
@@ -293,21 +293,21 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
resp.ERROR(c, "用户不存在")
return
}
// 检查用户算力是否足够
searchPower := 1 // 每次搜索消耗1点算力
if user.Power < searchPower {
resp.ERROR(c, "算力不足,无法执行网络搜索")
return
}
// 执行网络搜索
searchResults, err := crawler.SearchWeb(keyword, maxPages)
if err != nil {
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
return
}
// 扣减用户算力
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
Type: types.PowerConsume,
@@ -318,7 +318,7 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
resp.ERROR(c, "扣减算力失败:"+err.Error())
return
}
// 返回搜索结果
resp.SUCCESS(c, searchResults)
}

View File

@@ -0,0 +1,639 @@
package handler
import (
"fmt"
"strconv"
"time"
"geekai/core"
"geekai/core/types"
"geekai/service/jimeng"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// JimengHandler 即梦AI处理器
type JimengHandler struct {
BaseHandler
jimengService *jimeng.Service
}
// NewJimengHandler 创建即梦AI处理器
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service) *JimengHandler {
return &JimengHandler{
BaseHandler: BaseHandler{App: app},
jimengService: jimengService,
}
}
// TextToImage 文生图
func (h *JimengHandler) TextToImage(c *gin.Context) {
var req struct {
Prompt string `json:"prompt" binding:"required"`
Seed int64 `json:"seed"`
Scale float64 `json:"scale"`
Width int `json:"width"`
Height int `json:"height"`
UsePreLLM bool `json:"use_pre_llm"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 20 { // 文生图消耗20算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Scale == 0 {
req.Scale = 2.5
}
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
if req.Seed == 0 {
req.Seed = -1
}
// 构建任务参数
params := map[string]interface{}{
"seed": req.Seed,
"scale": req.Scale,
"width": req.Width,
"height": req.Height,
"use_pre_llm": req.UsePreLLM,
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeTextToImage,
Prompt: req.Prompt,
Params: params,
ReqKey: model.ReqKeyTextToImage,
Power: 20,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng text to image task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 20, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦文生图",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// ImageToImagePortrait 图生图人像写真
func (h *JimengHandler) ImageToImagePortrait(c *gin.Context) {
var req struct {
ImageInput string `json:"image_input" binding:"required"`
Prompt string `json:"prompt"`
Width int `json:"width"`
Height int `json:"height"`
Gpen float64 `json:"gpen"`
Skin float64 `json:"skin"`
SkinUnifi float64 `json:"skin_unifi"`
GenMode string `json:"gen_mode"`
Seed int64 `json:"seed"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误: "+err.Error())
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 30 { // 图生图消耗30算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
if req.Gpen == 0 {
req.Gpen = 0.4
}
if req.Skin == 0 {
req.Skin = 0.3
}
if req.GenMode == "" {
if req.Prompt != "" {
req.GenMode = jimeng.GenModeCreative
} else {
req.GenMode = jimeng.GenModeReference
}
}
if req.Seed == 0 {
req.Seed = -1
}
if req.Prompt == "" {
req.Prompt = "演唱会现场的合照,闪光灯拍摄"
}
// 构建任务参数
params := map[string]interface{}{
"image_input": req.ImageInput,
"width": req.Width,
"height": req.Height,
"gpen": req.Gpen,
"skin": req.Skin,
"skin_unifi": req.SkinUnifi,
"gen_mode": req.GenMode,
"seed": req.Seed,
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeImageToImagePortrait,
Prompt: req.Prompt,
Params: params,
ReqKey: model.ReqKeyImageToImagePortrait,
Power: 30,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng image to image portrait task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 30, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦图生图",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// ImageEdit 图像编辑
func (h *JimengHandler) ImageEdit(c *gin.Context) {
var req struct {
ImageUrls []string `json:"image_urls"`
BinaryDataBase64 []string `json:"binary_data_base64"`
Prompt string `json:"prompt" binding:"required"`
Seed int64 `json:"seed"`
Scale float64 `json:"scale"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误: "+err.Error())
return
}
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
resp.ERROR(c, "请提供图片URL或Base64数据")
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 25 { // 图像编辑消耗25算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Scale == 0 {
req.Scale = 0.5
}
if req.Seed == 0 {
req.Seed = -1
}
// 构建任务参数
params := map[string]interface{}{
"seed": req.Seed,
"scale": req.Scale,
}
if len(req.ImageUrls) > 0 {
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeImageEdit,
Prompt: req.Prompt,
Params: params,
ReqKey: model.ReqKeyImageEdit,
Power: 25,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng image edit task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 25, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦图像编辑",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// ImageEffects 图像特效
func (h *JimengHandler) ImageEffects(c *gin.Context) {
var req struct {
ImageInput1 string `json:"image_input1" binding:"required"`
TemplateId string `json:"template_id" binding:"required"`
Width int `json:"width"`
Height int `json:"height"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误: "+err.Error())
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 15 { // 图像特效消耗15算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
// 构建任务参数
params := map[string]interface{}{
"image_input1": req.ImageInput1,
"template_id": req.TemplateId,
"width": req.Width,
"height": req.Height,
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeImageEffects,
Prompt: "",
Params: params,
ReqKey: model.ReqKeyImageEffects,
Power: 15,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng image effects task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 15, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦图像特效",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// TextToVideo 文生视频
func (h *JimengHandler) TextToVideo(c *gin.Context) {
var req struct {
Prompt string `json:"prompt" binding:"required"`
Seed int64 `json:"seed"`
AspectRatio string `json:"aspect_ratio"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误: "+err.Error())
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 100 { // 文生视频消耗100算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Seed == 0 {
req.Seed = -1
}
if req.AspectRatio == "" {
req.AspectRatio = jimeng.AspectRatio16_9
}
// 构建任务参数
params := map[string]interface{}{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeTextToVideo,
Prompt: req.Prompt,
Params: params,
ReqKey: model.ReqKeyTextToVideo,
Power: 100,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng text to video task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 100, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦文生视频",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// ImageToVideo 图生视频
func (h *JimengHandler) ImageToVideo(c *gin.Context) {
var req struct {
ImageUrls []string `json:"image_urls"`
BinaryDataBase64 []string `json:"binary_data_base64"`
Prompt string `json:"prompt"`
Seed int64 `json:"seed"`
AspectRatio string `json:"aspect_ratio" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, "参数错误: "+err.Error())
return
}
if len(req.ImageUrls) == 0 && len(req.BinaryDataBase64) == 0 {
resp.ERROR(c, "请提供图片URL或Base64数据")
return
}
// 获取当前用户
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
// 检查用户算力
if user.Power < 120 { // 图生视频消耗120算力
resp.ERROR(c, "算力不足")
return
}
// 设置默认参数
if req.Seed == 0 {
req.Seed = -1
}
// 构建任务参数
params := map[string]interface{}{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
if len(req.ImageUrls) > 0 {
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
// 创建任务
taskReq := &jimeng.CreateTaskRequest{
Type: model.JimengJobTypeImageToVideo,
Prompt: req.Prompt,
Params: params,
ReqKey: model.ReqKeyImageToVideo,
Power: 120,
}
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng image to video task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
// 扣除用户算力
h.subUserPower(user.Id, 120, model.PowerLog{
Type: types.PowerConsume,
Model: "即梦图生视频",
Remark: fmt.Sprintf("任务ID%d", job.Id),
})
resp.SUCCESS(c, job)
}
// Jobs 获取任务列表
func (h *JimengHandler) Jobs(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
jobs, total, err := h.jimengService.GetUserJobs(user.Id, page, pageSize)
if err != nil {
logger.Errorf("get user jimeng jobs failed: %v", err)
resp.ERROR(c, "获取任务列表失败")
return
}
resp.SUCCESS(c, gin.H{
"jobs": jobs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// PendingCount 获取未完成任务数量
func (h *JimengHandler) PendingCount(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
count, err := h.jimengService.GetPendingTaskCount(user.Id)
if err != nil {
logger.Errorf("get pending task count failed: %v", err)
resp.ERROR(c, "获取待处理任务数量失败")
return
}
resp.SUCCESS(c, gin.H{"count": count})
}
// Remove 删除任务
func (h *JimengHandler) Remove(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
jobId := h.GetInt(c, "id", 0)
if jobId == 0 {
resp.ERROR(c, "参数错误")
return
}
if err := h.jimengService.DeleteJob(uint(jobId), user.Id); err != nil {
logger.Errorf("delete jimeng job failed: %v", err)
resp.ERROR(c, "删除任务失败")
return
}
resp.SUCCESS(c, gin.H{})
}
// Retry 重试任务
func (h *JimengHandler) Retry(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
jobIdStr := c.Param("id")
jobId, err := strconv.ParseUint(jobIdStr, 10, 32)
if err != nil {
resp.ERROR(c, "参数错误")
return
}
// 检查任务是否存在且属于当前用户
job, err := h.jimengService.GetJob(uint(jobId))
if err != nil {
resp.ERROR(c, "任务不存在")
return
}
if job.UserId != user.Id {
resp.ERROR(c, "无权限操作")
return
}
// 只有失败的任务才能重试
if job.Status != model.JimengJobStatusFailed {
resp.ERROR(c, "只有失败的任务才能重试")
return
}
// 重置任务状态
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JimengJobStatusPending, ""); err != nil {
logger.Errorf("reset job status failed: %v", err)
resp.ERROR(c, "重置任务状态失败")
return
}
// 重新推送到队列
task := map[string]interface{}{
"job_id": jobId,
"type": job.Type,
}
if err := h.jimengService.PushTaskToQueue(task); err != nil {
logger.Errorf("push retry task to queue failed: %v", err)
resp.ERROR(c, "推送重试任务失败")
return
}
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
}
// subUserPower 扣除用户算力
func (h *JimengHandler) subUserPower(userId uint, power int, powerLog model.PowerLog) {
session := h.DB.Session(&gorm.Session{})
// 更新用户算力
if err := session.Model(&model.User{}).Where("id = ?", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error; err != nil {
logger.Errorf("update user power failed: %v", err)
return
}
// 记录算力消费日志
powerLog.UserId = userId
powerLog.Amount = power
powerLog.Mark = types.PowerSub
powerLog.CreatedAt = time.Now()
if err := session.Create(&powerLog).Error; err != nil {
logger.Errorf("create power log failed: %v", err)
return
}
session.Commit()
}

View File

@@ -160,7 +160,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
}
job := model.MidJourneyJob{
Type: data.TaskType,

View File

@@ -48,7 +48,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -79,7 +79,7 @@ func (h *PromptHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -108,7 +108,7 @@ func (h *PromptHandler) Video(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -158,9 +158,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
}
func (h *PromptHandler) getPromptModel() string {
if h.App.SysConfig.TranslateModelId > 0 {
if h.App.SysConfig.AssistantModelId > 0 {
var chatModel model.ChatModel
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
return chatModel.Value
}
return "gpt-4o"

View File

@@ -131,7 +131,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
HdSteps: data.HdSteps,
},
UserId: userId,
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
}
job := model.SdJob{

View File

@@ -85,7 +85,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
}
// 插入数据库
job := model.VideoJob{
@@ -181,7 +181,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
Type: types.VideoKeLing,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.TranslateModelId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
Channel: data.Channel,
}
// 插入数据库

View File

@@ -17,6 +17,7 @@ import (
logger2 "geekai/logger"
"geekai/service"
"geekai/service/dalle"
"geekai/service/jimeng"
"geekai/service/mj"
"geekai/service/oss"
"geekai/service/payment"
@@ -140,6 +141,7 @@ func main() {
fx.Provide(handler.NewProductHandler),
fx.Provide(handler.NewConfigHandler),
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(handler.NewJimengHandler),
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
@@ -153,6 +155,9 @@ func main() {
fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
fx.Provide(func(app *core.AppServer, service *jimeng.Service) *admin.AdminJimengHandler {
return admin.NewAdminJimengHandler(app, service)
}),
// 创建服务
fx.Provide(sms.NewSendServiceManager),
@@ -203,6 +208,17 @@ func main() {
s.SyncTaskProgress()
s.DownloadFiles()
}),
// 即梦AI 服务
fx.Provide(func(config *types.AppConfig) *jimeng.Client {
return jimeng.NewClient(config.ApiConfig.JimengConfig.AccessKey, config.ApiConfig.JimengConfig.SecretKey)
}),
fx.Provide(jimeng.NewService),
fx.Provide(jimeng.NewConsumer),
fx.Invoke(func(consumer *jimeng.Consumer) {
consumer.Start()
go consumer.MonitorQueue()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
@@ -496,6 +512,29 @@ func main() {
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
// 即梦AI 路由
fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) {
group := s.Engine.Group("/api/jimeng")
group.POST("text-to-image", h.TextToImage)
group.POST("image-to-image-portrait", h.ImageToImagePortrait)
group.POST("image-edit", h.ImageEdit)
group.POST("image-effects", h.ImageEffects)
group.POST("text-to-video", h.TextToVideo)
group.POST("image-to-video", h.ImageToVideo)
group.GET("jobs", h.Jobs)
group.GET("pending-count", h.PendingCount)
group.GET("remove", h.Remove)
group.POST("retry/:id", h.Retry)
}),
fx.Invoke(func(s *core.AppServer, h *admin.AdminJimengHandler) {
group := s.Engine.Group("/api/admin/jimeng")
group.GET("jobs", h.Jobs)
group.GET("job/:id", h.JobDetail)
group.DELETE("job/:id", h.Remove)
group.POST("batch-remove", h.BatchRemove)
group.GET("stats", h.Stats)
}),
fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type")

View File

@@ -49,7 +49,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) {
logger.Infof("add a new DALL-E task to the task list: %+v", task)
s.taskQueue.RPush(task)
if err := s.taskQueue.RPush(task); err != nil {
logger.Errorf("push dall-e task to queue failed: %v", err)
}
}
func (s *Service) Run() {

View File

@@ -0,0 +1,332 @@
package jimeng
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
"geekai/logger"
)
var clientLogger = logger.GetLogger()
// Client 即梦API客户端
type Client struct {
accessKey string
secretKey string
region string
service string
baseURL string
httpClient *http.Client
}
// NewClient 创建即梦API客户端
func NewClient(accessKey, secretKey string) *Client {
return &Client{
accessKey: accessKey,
secretKey: secretKey,
region: "cn-north-1",
service: "cv",
baseURL: "https://visual.volcengineapi.com",
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// SubmitTask 提交任务
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
// 构建请求URL
queryParams := map[string]string{
"Action": "CVSync2AsyncSubmitTask",
"Version": "2022-08-31",
}
reqURL := c.buildURL(queryParams)
// 序列化请求体
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request body failed: %w", err)
}
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("create http request failed: %w", err)
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
// 签名请求
if err := c.signRequest(httpReq, reqBody); err != nil {
return nil, fmt.Errorf("sign request failed: %w", err)
}
// 发送请求
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send http request failed: %w", err)
}
defer resp.Body.Close()
// 读取响应
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body failed: %w", err)
}
clientLogger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
// 解析响应
var result SubmitTaskResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w", err)
}
return &result, nil
}
// QueryTask 查询任务
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
// 构建请求URL
queryParams := map[string]string{
"Action": "CVSync2AsyncGetResult",
"Version": "2022-08-31",
}
reqURL := c.buildURL(queryParams)
// 序列化请求体
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request body failed: %w", err)
}
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("create http request failed: %w", err)
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
// 签名请求
if err := c.signRequest(httpReq, reqBody); err != nil {
return nil, fmt.Errorf("sign request failed: %w", err)
}
// 发送请求
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send http request failed: %w", err)
}
defer resp.Body.Close()
// 读取响应
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body failed: %w", err)
}
clientLogger.Infof("Jimeng QueryTask Response: %s", string(respBody))
// 解析响应
var result QueryTaskResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w", err)
}
return &result, nil
}
// SubmitSyncTask 提交同步任务(仅用于文生图)
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
// 构建请求URL
queryParams := map[string]string{
"Action": "CVProcess",
"Version": "2022-08-31",
}
reqURL := c.buildURL(queryParams)
// 序列化请求体
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request body failed: %w", err)
}
// 创建HTTP请求
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("create http request failed: %w", err)
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
// 签名请求
if err := c.signRequest(httpReq, reqBody); err != nil {
return nil, fmt.Errorf("sign request failed: %w", err)
}
// 发送请求
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("send http request failed: %w", err)
}
defer resp.Body.Close()
// 读取响应
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body failed: %w", err)
}
clientLogger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
// 解析响应
var result QueryTaskResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w", err)
}
return &result, nil
}
// buildURL 构建请求URL
func (c *Client) buildURL(queryParams map[string]string) string {
u, _ := url.Parse(c.baseURL)
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
return u.String()
}
// signRequest 签名请求
func (c *Client) signRequest(req *http.Request, body []byte) error {
now := time.Now().UTC()
// 设置基本头部
req.Header.Set("X-Date", now.Format("20060102T150405Z"))
req.Header.Set("Host", req.URL.Host)
// 计算内容哈希
contentHash := sha256.Sum256(body)
req.Header.Set("X-Content-Sha256", hex.EncodeToString(contentHash[:]))
// 构建签名字符串
canonicalRequest := c.buildCanonicalRequest(req)
credentialScope := fmt.Sprintf("%s/%s/%s/request", now.Format("20060102"), c.region, c.service)
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
now.Format("20060102T150405Z"), credentialScope, sha256Hash(canonicalRequest))
// 计算签名
signature := c.calculateSignature(stringToSign, now)
// 设置Authorization头部
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.accessKey, credentialScope, c.getSignedHeaders(req), signature)
req.Header.Set("Authorization", authorization)
return nil
}
// buildCanonicalRequest 构建规范请求
func (c *Client) buildCanonicalRequest(req *http.Request) string {
// HTTP方法
method := req.Method
// 规范URI
uri := req.URL.Path
if uri == "" {
uri = "/"
}
// 规范查询字符串
query := req.URL.Query()
var queryParts []string
for k, v := range query {
for _, val := range v {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(val)))
}
}
sort.Strings(queryParts)
canonicalQuery := strings.Join(queryParts, "&")
// 规范头部
var headerParts []string
headers := make(map[string]string)
for k, v := range req.Header {
key := strings.ToLower(k)
if len(v) > 0 {
headers[key] = strings.TrimSpace(v[0])
}
}
var headerKeys []string
for k := range headers {
headerKeys = append(headerKeys, k)
}
sort.Strings(headerKeys)
for _, k := range headerKeys {
headerParts = append(headerParts, fmt.Sprintf("%s:%s", k, headers[k]))
}
canonicalHeaders := strings.Join(headerParts, "\n") + "\n"
// 签名头部
signedHeaders := c.getSignedHeaders(req)
// 载荷哈希
payloadHash := req.Header.Get("X-Content-Sha256")
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
method, uri, canonicalQuery, canonicalHeaders, signedHeaders, payloadHash)
}
// getSignedHeaders 获取签名头部
func (c *Client) getSignedHeaders(req *http.Request) string {
var headers []string
for k := range req.Header {
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
return strings.Join(headers, ";")
}
// calculateSignature 计算签名
func (c *Client) calculateSignature(stringToSign string, t time.Time) string {
kDate := hmacSha256([]byte("HMAC-SHA256"+c.secretKey), []byte(t.Format("20060102")))
kRegion := hmacSha256(kDate, []byte(c.region))
kService := hmacSha256(kRegion, []byte(c.service))
kSigning := hmacSha256(kService, []byte("request"))
signature := hmacSha256(kSigning, []byte(stringToSign))
return hex.EncodeToString(signature)
}
// hmacSha256 计算HMAC-SHA256
func hmacSha256(key []byte, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
// sha256Hash 计算SHA256哈希
func sha256Hash(data string) string {
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}

View File

@@ -0,0 +1,177 @@
package jimeng
import (
"context"
"time"
"geekai/logger"
"geekai/store/model"
)
var jimengLogger = logger.GetLogger()
// Consumer 即梦任务消费者
type Consumer struct {
service *Service
ctx context.Context
cancel context.CancelFunc
}
// NewConsumer 创建即梦任务消费者
func NewConsumer(service *Service) *Consumer {
ctx, cancel := context.WithCancel(context.Background())
return &Consumer{
service: service,
ctx: ctx,
cancel: cancel,
}
}
// Start 启动消费者
func (c *Consumer) Start() {
jimengLogger.Info("Starting Jimeng task consumer...")
go c.consume()
}
// Stop 停止消费者
func (c *Consumer) Stop() {
jimengLogger.Info("Stopping Jimeng task consumer...")
c.cancel()
}
// consume 消费任务
func (c *Consumer) consume() {
for {
select {
case <-c.ctx.Done():
jimengLogger.Info("Jimeng task consumer stopped")
return
default:
c.processTask()
}
}
}
// processTask 处理任务
func (c *Consumer) processTask() {
// 从队列中获取任务
var task map[string]interface{}
if err := c.service.taskQueue.LPop(&task); err != nil {
// 队列为空等待1秒后重试
time.Sleep(time.Second)
return
}
// 解析任务
jobIdFloat, ok := task["job_id"].(float64)
if !ok {
jimengLogger.Errorf("invalid job_id in task: %v", task)
return
}
jobId := uint(jobIdFloat)
taskType, ok := task["type"].(string)
if !ok {
jimengLogger.Errorf("invalid task type in task: %v", task)
return
}
jimengLogger.Infof("Processing Jimeng task: job_id=%d, type=%s", jobId, taskType)
// 处理任务
if err := c.service.ProcessTask(jobId); err != nil {
jimengLogger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
// 任务失败,直接标记为失败状态,不进行重试
c.service.UpdateJobStatus(jobId, model.JimengJobStatusFailed, err.Error())
} else {
jimengLogger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
}
}
// TaskQueueStatus 任务队列状态
type TaskQueueStatus struct {
QueueLength int `json:"queue_length"`
ActiveTasks int `json:"active_tasks"`
}
// GetQueueStatus 获取队列状态
func (c *Consumer) GetQueueStatus() (*TaskQueueStatus, error) {
// 获取队列长度
length, err := c.service.taskQueue.Size()
if err != nil {
return nil, err
}
// 获取活跃任务数(正在处理的任务)
activeTasks, err := c.service.GetPendingTaskCount(0) // 0表示所有用户
if err != nil {
activeTasks = 0
}
return &TaskQueueStatus{
QueueLength: int(length),
ActiveTasks: int(activeTasks),
}, nil
}
// MonitorQueue 监控队列状态
func (c *Consumer) MonitorQueue() {
ticker := time.NewTicker(30 * time.Second) // 每30秒监控一次
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
status, err := c.GetQueueStatus()
if err != nil {
jimengLogger.Errorf("get queue status failed: %v", err)
continue
}
if status.QueueLength > 0 || status.ActiveTasks > 0 {
jimengLogger.Infof("Jimeng queue status: queue_length=%d, active_tasks=%d",
status.QueueLength, status.ActiveTasks)
}
}
}
}
// PushTaskToQueue 推送任务到队列(用于手动重试)
func (c *Consumer) PushTaskToQueue(task map[string]interface{}) error {
return c.service.taskQueue.RPush(task)
}
// GetTaskStats 获取任务统计信息
func (c *Consumer) GetTaskStats() (map[string]interface{}, error) {
type StatResult struct {
Status string `json:"status"`
Count int64 `json:"count"`
}
var stats []StatResult
err := c.service.db.Model(&model.JimengJob{}).
Select("status, COUNT(*) as count").
Group("status").
Find(&stats).Error
if err != nil {
return nil, err
}
result := map[string]interface{}{
"total": int64(0),
"completed": int64(0),
"processing": int64(0),
"failed": int64(0),
"pending": int64(0),
}
for _, stat := range stats {
result["total"] = result["total"].(int64) + stat.Count
result[stat.Status] = stat.Count
}
return result, nil
}

View File

@@ -0,0 +1,633 @@
package jimeng
import (
"encoding/json"
"fmt"
"strconv"
"time"
"gorm.io/gorm"
"geekai/logger"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
)
var serviceLogger = logger.GetLogger()
// Service 即梦服务
type Service struct {
db *gorm.DB
redis *redis.Client
taskQueue *store.RedisQueue
client *Client
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client, client *Client) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
return &Service{
db: db,
redis: redisCli,
taskQueue: taskQueue,
client: client,
}
}
// CreateTask 创建任务
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
// 生成任务ID
taskId := utils.RandString(20)
// 序列化任务参数
paramsJson, err := json.Marshal(req.Params)
if err != nil {
return nil, fmt.Errorf("marshal task params failed: %w", err)
}
// 创建任务记录
job := &model.JimengJob{
UserId: userId,
TaskId: taskId,
Type: req.Type,
ReqKey: req.ReqKey,
Prompt: req.Prompt,
TaskParams: string(paramsJson),
Status: model.JimengJobStatusPending,
Power: req.Power,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 保存到数据库
if err := s.db.Create(job).Error; err != nil {
return nil, fmt.Errorf("create jimeng job failed: %w", err)
}
// 推送到任务队列
task := map[string]any{
"job_id": job.Id,
"type": job.Type,
}
if err := s.taskQueue.RPush(task); err != nil {
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
}
return job, nil
}
// ProcessTask 处理任务
func (s *Service) ProcessTask(jobId uint) error {
// 获取任务记录
var job model.JimengJob
if err := s.db.First(&job, jobId).Error; err != nil {
return fmt.Errorf("get jimeng job failed: %w", err)
}
// 更新任务状态为处理中
if err := s.UpdateJobStatus(job.Id, model.JimengJobStatusProcessing, ""); err != nil {
return fmt.Errorf("update job status failed: %w", err)
}
// 根据任务类型处理
switch job.Type {
case model.JimengJobTypeTextToImage:
return s.processTextToImage(&job)
case model.JimengJobTypeImageToImagePortrait:
return s.processImageToImagePortrait(&job)
case model.JimengJobTypeImageEdit:
return s.processImageEdit(&job)
case model.JimengJobTypeImageEffects:
return s.processImageEffects(&job)
case model.JimengJobTypeTextToVideo:
return s.processTextToVideo(&job)
case model.JimengJobTypeImageToVideo:
return s.processImageToVideo(&job)
default:
return fmt.Errorf("unsupported task type: %s", job.Type)
}
}
// processTextToImage 处理文生图任务
func (s *Service) processTextToImage(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 设置参数
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
if usePreLlm, ok := params["use_pre_llm"]; ok {
if usePreLlmVal, ok := usePreLlm.(bool); ok {
req.UsePreLLM = usePreLlmVal
}
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// processImageToImagePortrait 处理图生图人像写真任务
func (s *Service) processImageToImagePortrait(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 设置图像输入
if imageInput, ok := params["image_input"].(string); ok {
req.ImageInput = imageInput
}
// 设置其他参数
if gpen, ok := params["gpen"]; ok {
if gpenVal, ok := gpen.(float64); ok {
req.Gpen = gpenVal
}
}
if skin, ok := params["skin"]; ok {
if skinVal, ok := skin.(float64); ok {
req.Skin = skinVal
}
}
if skinUnifi, ok := params["skin_unifi"]; ok {
if skinUnifiVal, ok := skinUnifi.(float64); ok {
req.SkinUnifi = skinUnifiVal
}
}
if genMode, ok := params["gen_mode"].(string); ok {
req.GenMode = genMode
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// processImageEdit 处理图像编辑任务
func (s *Service) processImageEdit(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 设置图像输入
if imageUrls, ok := params["image_urls"].([]any); ok {
for _, url := range imageUrls {
if urlStr, ok := url.(string); ok {
req.ImageUrls = append(req.ImageUrls, urlStr)
}
}
}
if binaryData, ok := params["binary_data_base64"].([]any); ok {
for _, data := range binaryData {
if dataStr, ok := data.(string); ok {
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
}
}
}
// 设置其他参数
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if scale, ok := params["scale"]; ok {
if scaleVal, ok := scale.(float64); ok {
req.Scale = scaleVal
}
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// processImageEffects 处理图像特效任务
func (s *Service) processImageEffects(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
}
// 设置图像输入
if imageInput1, ok := params["image_input1"].(string); ok {
req.ImageInput1 = imageInput1
}
if templateId, ok := params["template_id"].(string); ok {
req.TemplateId = templateId
}
if width, ok := params["width"]; ok {
if widthVal, ok := width.(float64); ok {
req.Width = int(widthVal)
}
}
if height, ok := params["height"]; ok {
if heightVal, ok := height.(float64); ok {
req.Height = int(heightVal)
}
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// processTextToVideo 处理文生视频任务
func (s *Service) processTextToVideo(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 设置参数
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// processImageToVideo 处理图生视频任务
func (s *Service) processImageToVideo(job *model.JimengJob) error {
// 解析任务参数
var params map[string]any
if err := json.Unmarshal([]byte(job.TaskParams), &params); err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("parse task params failed: %v", err))
}
// 构建请求
req := &SubmitTaskRequest{
ReqKey: job.ReqKey,
Prompt: job.Prompt,
}
// 设置图像输入
if imageUrls, ok := params["image_urls"].([]any); ok {
for _, url := range imageUrls {
if urlStr, ok := url.(string); ok {
req.ImageUrls = append(req.ImageUrls, urlStr)
}
}
}
if binaryData, ok := params["binary_data_base64"].([]any); ok {
for _, data := range binaryData {
if dataStr, ok := data.(string); ok {
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
}
}
}
// 设置其他参数
if seed, ok := params["seed"]; ok {
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
req.Seed = seedVal
}
}
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
req.AspectRatio = aspectRatio
}
// 提交异步任务
resp, err := s.client.SubmitTask(req)
if err != nil {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
}
if resp.Code != 10000 {
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
}
// 更新任务ID和原始数据
rawData, _ := json.Marshal(resp)
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
"task_id": resp.Data.TaskId,
"raw_data": string(rawData),
"updated_at": time.Now(),
}).Error; err != nil {
serviceLogger.Errorf("update jimeng job task_id failed: %v", err)
}
// 开始轮询任务状态
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
}
// pollTaskStatus 轮询任务状态
func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
maxRetries := 60 // 最大重试次数60次 * 5秒 = 5分钟
retryCount := 0
for retryCount < maxRetries {
time.Sleep(5 * time.Second) // 等待5秒
// 查询任务状态
resp, err := s.client.QueryTask(&QueryTaskRequest{
ReqKey: reqKey,
TaskId: taskId,
ReqJson: `{"return_url":true}`,
})
if err != nil {
serviceLogger.Errorf("query jimeng task status failed: %v", err)
retryCount++
continue
}
// 更新原始数据
rawData, _ := json.Marshal(resp)
s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData))
if resp.Code != 10000 {
return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message))
}
switch resp.Data.Status {
case TaskStatusDone:
// 任务完成,更新结果
updates := map[string]any{
"status": model.JimengJobStatusCompleted,
"progress": 100,
"updated_at": time.Now(),
}
// 设置结果URL
if len(resp.Data.ImageUrls) > 0 {
updates["img_url"] = resp.Data.ImageUrls[0]
}
if resp.Data.VideoUrl != "" {
updates["video_url"] = resp.Data.VideoUrl
}
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
case TaskStatusInQueue:
// 任务在队列中
s.UpdateJobProgress(jobId, 10)
case TaskStatusGenerating:
// 任务处理中
s.UpdateJobProgress(jobId, 50)
case TaskStatusNotFound, TaskStatusExpired:
// 任务未找到或已过期
return s.handleTaskError(jobId, fmt.Sprintf("task not found or expired: %s", resp.Data.Status))
default:
serviceLogger.Warnf("unknown task status: %s", resp.Data.Status)
}
retryCount++
}
// 超时处理
return s.handleTaskError(jobId, "task timeout")
}
// UpdateJobStatus 更新任务状态
func (s *Service) UpdateJobStatus(jobId uint, status, errMsg string) error {
updates := map[string]any{
"status": status,
"updated_at": time.Now(),
}
if errMsg != "" {
updates["err_msg"] = errMsg
}
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
}
// UpdateJobProgress 更新任务进度
func (s *Service) UpdateJobProgress(jobId uint, progress int) error {
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(map[string]any{
"progress": progress,
"updated_at": time.Now(),
}).Error
}
// handleTaskError 处理任务错误
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
serviceLogger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
return s.UpdateJobStatus(jobId, model.JimengJobStatusFailed, errMsg)
}
// GetJob 获取任务
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
var job model.JimengJob
if err := s.db.First(&job, jobId).Error; err != nil {
return nil, err
}
return &job, nil
}
// GetUserJobs 获取用户任务列表
func (s *Service) GetUserJobs(userId uint, page, pageSize int) ([]*model.JimengJob, int64, error) {
var jobs []*model.JimengJob
var total int64
query := s.db.Model(&model.JimengJob{}).Where("user_id = ?", userId)
// 统计总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 分页查询
offset := (page - 1) * pageSize
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&jobs).Error; err != nil {
return nil, 0, err
}
return jobs, total, nil
}
// GetPendingTaskCount 获取用户未完成任务数量
func (s *Service) GetPendingTaskCount(userId uint) (int64, error) {
var count int64
err := s.db.Model(&model.JimengJob{}).Where("user_id = ? AND status IN (?)", userId,
[]string{model.JimengJobStatusPending, model.JimengJobStatusProcessing}).Count(&count).Error
return count, err
}
// DeleteJob 删除任务
func (s *Service) DeleteJob(jobId uint, userId uint) error {
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
}
// PushTaskToQueue 推送任务到队列
func (s *Service) PushTaskToQueue(task map[string]interface{}) error {
return s.taskQueue.RPush(task)
}

163
api/service/jimeng/types.go Normal file
View File

@@ -0,0 +1,163 @@
package jimeng
import "time"
// SubmitTaskRequest 提交任务请求
type SubmitTaskRequest struct {
ReqKey string `json:"req_key"`
// 文生图参数
Prompt string `json:"prompt,omitempty"`
Seed int64 `json:"seed,omitempty"`
Scale float64 `json:"scale,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
UsePreLLM bool `json:"use_pre_llm,omitempty"`
// 图生图参数
ImageInput string `json:"image_input,omitempty"`
ImageUrls []string `json:"image_urls,omitempty"`
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
Gpen float64 `json:"gpen,omitempty"`
Skin float64 `json:"skin,omitempty"`
SkinUnifi float64 `json:"skin_unifi,omitempty"`
GenMode string `json:"gen_mode,omitempty"`
// 图像编辑参数
// 图像特效参数
ImageInput1 string `json:"image_input1,omitempty"`
TemplateId string `json:"template_id,omitempty"`
// 视频生成参数
AspectRatio string `json:"aspect_ratio,omitempty"`
}
// SubmitTaskResponse 提交任务响应
type SubmitTaskResponse struct {
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
Data struct {
TaskId string `json:"task_id"`
} `json:"data"`
}
// QueryTaskRequest 查询任务请求
type QueryTaskRequest struct {
ReqKey string `json:"req_key"`
TaskId string `json:"task_id"`
ReqJson string `json:"req_json,omitempty"`
}
// QueryTaskResponse 查询任务响应
type QueryTaskResponse struct {
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
Data struct {
AlgorithmBaseResp struct {
StatusCode int `json:"status_code"`
StatusMessage string `json:"status_message"`
} `json:"algorithm_base_resp"`
BinaryDataBase64 []string `json:"binary_data_base64"`
ImageUrls []string `json:"image_urls"`
VideoUrl string `json:"video_url"`
RespData string `json:"resp_data"`
Status string `json:"status"`
LlmResult string `json:"llm_result"`
PeResult string `json:"pe_result"`
PredictTagsResult string `json:"predict_tags_result"`
RephraserResult string `json:"rephraser_result"`
VlmResult string `json:"vlm_result"`
InferCtx interface{} `json:"infer_ctx"`
} `json:"data"`
}
// TaskStatus 任务状态
const (
TaskStatusInQueue = "in_queue" // 任务已提交
TaskStatusGenerating = "generating" // 任务处理中
TaskStatusDone = "done" // 处理完成
TaskStatusNotFound = "not_found" // 任务未找到
TaskStatusExpired = "expired" // 任务已过期
)
// CreateTaskRequest 创建任务请求
type CreateTaskRequest struct {
Type string `json:"type"`
Prompt string `json:"prompt"`
Params map[string]interface{} `json:"params"`
ReqKey string `json:"req_key"`
ImageUrls []string `json:"image_urls,omitempty"`
Power int `json:"power,omitempty"`
}
// TaskInfo 任务信息
type TaskInfo struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
TaskId string `json:"task_id"`
Type string `json:"type"`
ReqKey string `json:"req_key"`
Prompt string `json:"prompt"`
TaskParams string `json:"task_params"`
ImgURL string `json:"img_url"`
VideoURL string `json:"video_url"`
Progress int `json:"progress"`
Status string `json:"status"`
ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// LogoInfo 水印信息
type LogoInfo struct {
AddLogo bool `json:"add_logo"`
Position int `json:"position"`
Language int `json:"language"`
Opacity float64 `json:"opacity"`
LogoTextContent string `json:"logo_text_content"`
}
// ReqJsonConfig 查询配置
type ReqJsonConfig struct {
ReturnUrl bool `json:"return_url"`
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
}
// ImageEffectTemplate 图像特效模板
const (
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
TemplateIdMyWorld = "my_world" // 像素世界风
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
TemplateIdGlassBall = "glass_ball" // 玻璃球
)
// AspectRatio 视频宽高比
const (
AspectRatio16_9 = "16:9" // 1280×720
AspectRatio9_16 = "9:16" // 720×1280
AspectRatio1_1 = "1:1" // 960×960
AspectRatio4_3 = "4:3" // 960×720
AspectRatio3_4 = "3:4" // 720×960
AspectRatio21_9 = "21:9" // 1680×720
AspectRatio9_21 = "9:21" // 720×1680
)
// GenMode 生成模式
const (
GenModeCreative = "creative" // 提示词模式
GenModeReference = "reference" // 全参考模式
GenModeReferenceChar = "reference_char" // 人物参考模式
)

View File

@@ -212,7 +212,9 @@ func (s *Service) DownloadImages() {
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
if err := s.taskQueue.RPush(task); err != nil {
logger.Errorf("push mj task to queue failed: %v", err)
}
}
// SyncTaskProgress 异步拉取任务

View File

@@ -253,7 +253,9 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, err
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
if err := s.taskQueue.RPush(task); err != nil {
logger.Errorf("push sd task to queue failed: %v", err)
}
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务

View File

@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
func (s *Service) PushTask(task types.SunoTask) {
logger.Infof("add a new Suno task to the task list: %+v", task)
s.taskQueue.RPush(task)
if err := s.taskQueue.RPush(task); err != nil {
logger.Errorf("push suno task to queue failed: %v", err)
}
}
func (s *Service) Run() {

View File

@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
if err := s.taskQueue.RPush(task); err != nil {
logger.Errorf("push video task to queue failed: %v", err)
}
}
func (s *Service) Run() {

View File

@@ -0,0 +1,58 @@
package model
import (
"time"
)
// JimengJob 即梦AI任务模型
type JimengJob struct {
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int;not null;index;comment:用户ID" json:"user_id"`
TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"`
Type string `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"`
ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"`
Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"`
TaskParams string `gorm:"column:task_params;type:text;comment:任务参数JSON" json:"task_params"`
ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"`
VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"`
RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"`
Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"`
Status string `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"`
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
Power int `gorm:"column:power;type:int;default:0;comment:消耗算力" json:"power"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
}
// JimengJobStatus 即梦任务状态常量
const (
JimengJobStatusPending = "pending"
JimengJobStatusProcessing = "processing"
JimengJobStatusCompleted = "completed"
JimengJobStatusFailed = "failed"
)
// JimengJobType 即梦任务类型常量
const (
JimengJobTypeTextToImage = "text_to_image" // 文生图
JimengJobTypeImageToImagePortrait = "image_to_image_portrait" // 图生图人像写真
JimengJobTypeImageEdit = "image_edit" // 图像编辑
JimengJobTypeImageEffects = "image_effects" // 图像特效
JimengJobTypeTextToVideo = "text_to_video" // 文生视频
JimengJobTypeImageToVideo = "image_to_video" // 图生视频
)
// ReqKey 常量定义
const (
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
)
// TableName 返回数据表名称
func (JimengJob) TableName() string {
return "chatgpt_jimeng_jobs"
}

View File

@@ -10,6 +10,7 @@ package store
import (
"context"
"geekai/utils"
"github.com/go-redis/redis/v8"
)
@@ -23,15 +24,15 @@ func NewRedisQueue(name string, client *redis.Client) *RedisQueue {
return &RedisQueue{name: name, client: client, ctx: context.Background()}
}
func (q *RedisQueue) RPush(value interface{}) {
q.client.RPush(q.ctx, q.name, utils.JsonEncode(value))
func (q *RedisQueue) RPush(value any) error {
return q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)).Err()
}
func (q *RedisQueue) LPush(value interface{}) {
q.client.LPush(q.ctx, q.name, utils.JsonEncode(value))
func (q *RedisQueue) LPush(value any) error {
return q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)).Err()
}
func (q *RedisQueue) LPop(value interface{}) error {
func (q *RedisQueue) LPop(value any) error {
result, err := q.client.BLPop(q.ctx, 0, q.name).Result()
if err != nil {
return err
@@ -39,10 +40,18 @@ func (q *RedisQueue) LPop(value interface{}) error {
return utils.JsonDecode(result[1], value)
}
func (q *RedisQueue) RPop(value interface{}) error {
func (q *RedisQueue) RPop(value any) error {
result, err := q.client.BRPop(q.ctx, 0, q.name).Result()
if err != nil {
return err
}
return utils.JsonDecode(result[1], value)
}
func (q *RedisQueue) Size() (int64, error) {
return q.client.LLen(q.ctx, q.name).Result()
}
func (q *RedisQueue) Clear() error {
return q.client.Del(q.ctx, q.name).Err()
}

View File

@@ -0,0 +1,21 @@
package vo
// JimengJob 即梦AI任务VO
type JimengJob struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
TaskId string `json:"task_id"`
Type string `json:"type"`
ReqKey string `json:"req_key"`
Prompt string `json:"prompt"`
TaskParams string `json:"task_params"`
ImgURL string `json:"img_url"`
VideoURL string `json:"video_url"`
RawData string `json:"raw_data"`
Progress int `json:"progress"`
Status string `json:"status"`
ErrMsg string `json:"err_msg"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"` // 时间戳
UpdatedAt int64 `json:"updated_at"` // 时间戳
}

View File

@@ -1,55 +0,0 @@
package main
import (
"crypto/rand"
"encoding/hex"
"fmt"
"sync"
)
const (
codeLength = 32 // 兑换码长度
)
var (
codeMap = make(map[string]bool)
mapMutex = &sync.Mutex{}
)
// GenerateUniqueCode 生成唯一兑换码
func GenerateUniqueCode() (string, error) {
for {
code, err := generateCode()
if err != nil {
return "", err
}
mapMutex.Lock()
if !codeMap[code] {
codeMap[code] = true
mapMutex.Unlock()
return code, nil
}
mapMutex.Unlock()
}
}
// generateCode 生成兑换码
func generateCode() (string, error) {
bytes := make([]byte, codeLength/2) // 因为 hex 编码会使长度翻倍
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func main() {
for i := 0; i < 10; i++ {
code, err := GenerateUniqueCode()
if err != nil {
fmt.Println("Error generating code:", err)
return
}
fmt.Println("Generated code:", code)
}
}