add user lock for chat api, Prevent insufficient deduction of user power caused by submitting multiple requests at one time

This commit is contained in:
GeekMaster
2025-09-12 15:05:14 +08:00
parent 65fb58585c
commit c5badb3e13
43 changed files with 309 additions and 429 deletions

View File

@@ -50,37 +50,16 @@ func (h *JimengHandler) RegisterRoutes() {
}
}
// JimengTaskRequest 统一任务请求结构体
// 支持所有生图和生成视频类型
type JimengTaskRequest struct {
TaskType string `json:"task_type" binding:"required"`
Prompt string `json:"prompt"`
ImageInput string `json:"image_input"`
ImageUrls []string `json:"image_urls"`
BinaryDataBase64 []string `json:"binary_data_base64"`
Scale float64 `json:"scale"`
Width int `json:"width"`
Height int `json:"height"`
Gpen float64 `json:"gpen"`
Skin float64 `json:"skin"`
SkinUnifi float64 `json:"skin_unifi"`
GenMode string `json:"gen_mode"`
Seed int64 `json:"seed"`
UsePreLLM bool `json:"use_pre_llm"`
TemplateId string `json:"template_id"`
AspectRatio string `json:"aspect_ratio"`
}
// CreateTask 统一任务创建接口
func (h *JimengHandler) CreateTask(c *gin.Context) {
var req JimengTaskRequest
var req types.JimengTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 文本审核
if h.App.SysConfig.Moderation.Enable {
if h.App.SysConfig.Moderation.Enable && req.Prompt != "" {
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
@@ -103,166 +82,46 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
}
// 新增:除图像特效外,其他任务类型必须有提示词
if req.TaskType != "image_effects" && req.Prompt == "" {
resp.ERROR(c, "提示词不能为空")
if req.Prompt == "" && len(req.ImageUrls) == 0 {
resp.ERROR(c, "提示词和图片不能同时为空")
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
if req.Seed == 0 {
req.Seed = -1
}
// 获取算力消耗
var powerCost int
var taskType model.JMTaskType
var params map[string]any
var reqKey string
var modelName string
// if user.Power < powerCost {
// resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
// return
// }
switch req.TaskType {
case "text_to_image":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImage)
taskType = model.JMTaskTypeImage
reqKey = jimeng.ReqKeyTextToImage
modelName = "即梦文生图"
if req.Scale == 0 {
req.Scale = 2.5
}
params = map[string]any{
"seed": req.Seed,
"scale": req.Scale,
"width": req.Width,
"height": req.Height,
"use_pre_llm": req.UsePreLLM,
}
case "image_to_image":
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
taskType = model.JMTaskTypeVideo
reqKey = jimeng.ReqKeyImageToImagePortrait
modelName = "即梦图生图"
if req.Gpen == 0 {
req.Gpen = 0.4
}
if req.Skin == 0 {
req.Skin = 0.3
}
if req.GenMode == "" {
if req.Prompt != "" {
req.GenMode = jimeng.GenModeCreative
} else {
req.GenMode = jimeng.GenModeReference
}
}
params = map[string]any{
"image_input": req.ImageInput,
"width": req.Width,
"height": req.Height,
"gpen": req.Gpen,
"skin": req.Skin,
"skin_unifi": req.SkinUnifi,
"gen_mode": req.GenMode,
"seed": req.Seed,
}
case "image_edit":
powerCost = h.getPowerFromConfig(model.JMTaskTypeVirtualHuman)
taskType = model.JMTaskTypeVirtualHuman
reqKey = jimeng.ReqKeyImageEdit
modelName = "即梦图像编辑"
if req.Scale == 0 {
req.Scale = 0.5
}
params = map[string]any{
"seed": req.Seed,
"scale": req.Scale,
}
params["image_urls"] = []string{req.ImageInput}
case "image_effects":
powerCost = h.getPowerFromConfig(model.JMTaskTypeActionTransfer)
taskType = model.JMTaskTypeActionTransfer
reqKey = jimeng.ReqKeyImageEffects
modelName = "即梦图像特效"
if req.Width == 0 {
req.Width = 1328
}
if req.Height == 0 {
req.Height = 1328
}
params = map[string]any{
"image_input1": req.ImageInput,
"template_id": req.TemplateId,
"width": req.Width,
"height": req.Height,
}
case "text_to_video":
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
taskType = model.JMTaskTypeVideo
reqKey = jimeng.ReqKeyTextToVideo
modelName = "即梦文生视频"
if req.AspectRatio == "" {
req.AspectRatio = jimeng.AspectRatio16_9
}
params = map[string]any{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
case "image_to_video":
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
taskType = model.JMTaskTypeVideo
reqKey = jimeng.ReqKeyImageToVideo
modelName = "即梦图生视频"
params = map[string]any{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
}
if len(req.ImageUrls) > 0 {
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
default:
resp.ERROR(c, "不支持的任务类型")
return
}
// taskReq := &jimeng.CreateTaskRequest{
// Type: taskType,
// Prompt: req.Prompt,
// Params: params,
// ReqKey: reqKey,
// Power: powerCost,
// }
if user.Power < powerCost {
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
return
}
// job, err := h.jimengService.CreateTask(user.Id, taskReq)
// if err != nil {
// logger.Errorf("create jimeng task failed: %v", err)
// resp.ERROR(c, "创建任务失败")
// return
// }
taskReq := &jimeng.CreateTaskRequest{
Type: taskType,
Prompt: req.Prompt,
Params: params,
ReqKey: reqKey,
Power: powerCost,
}
// h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
// Type: types.PowerConsume,
// Model: "jimeng",
// Remark: fmt.Sprintf("%s任务ID%d", modelName, job.Id),
// })
job, err := h.jimengService.CreateTask(user.Id, taskReq)
if err != nil {
logger.Errorf("create jimeng task failed: %v", err)
resp.ERROR(c, "创建任务失败")
return
}
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
Type: types.PowerConsume,
Model: "jimeng",
Remark: fmt.Sprintf("%s任务ID%d", modelName, job.Id),
})
resp.SUCCESS(c, job)
resp.SUCCESS(c)
}
// Jobs 获取任务列表
@@ -287,9 +146,9 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
switch req.Filter {
case "image":
query = query.Where("type = ?", model.JMTaskTypeImage)
query = query.Where("type = ?", types.JMTaskTypeImage)
case "video":
query = query.Where("type = ?", model.JMTaskTypeVideo)
query = query.Where("type = ?", types.JMTaskTypeVideo)
}
if len(req.Ids) > 0 {
@@ -349,7 +208,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
}
// 正在运行中的任务不能删除
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
if job.Status == types.JMTaskStatusGenerating || job.Status == types.JMTaskStatusInQueue {
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
return
}
@@ -362,7 +221,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
}
// 失败任务删除后退回算力
if job.Status != model.JMTaskStatusFailed {
if job.Status != types.JMTaskStatusFailed {
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "jimeng",
@@ -403,13 +262,13 @@ func (h *JimengHandler) Retry(c *gin.Context) {
}
// 只有失败的任务才能重试
if job.Status != model.JMTaskStatusFailed {
if job.Status != types.JMTaskStatusFailed {
resp.ERROR(c, "只有失败的任务才能重试")
return
}
// 重置任务状态
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
if err := h.jimengService.UpdateJobStatus(uint(jobId), types.JMTaskStatusInQueue, ""); err != nil {
logger.Errorf("reset job status failed: %v", err)
resp.ERROR(c, "重置任务状态失败")
return
@@ -426,17 +285,17 @@ func (h *JimengHandler) Retry(c *gin.Context) {
}
// getPowerFromConfig 从配置中获取指定类型的算力消耗
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int {
config := h.App.SysConfig.Jimeng
switch taskType {
case model.JMTaskTypeImage:
case types.JMTaskTypeImage:
return config.Power.Image
case model.JMTaskTypeVideo:
case types.JMTaskTypeVideo:
return config.Power.Video
case model.JMTaskTypeVirtualHuman:
case types.JMTaskTypeVirtualHuman:
return config.Power.VirtualHuman
case model.JMTaskTypeActionTransfer:
case types.JMTaskTypeActionTransfer:
return config.Power.ActionTransfer
default:
return 10