mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-10 19:54:25 +08:00
add user lock for chat api, Prevent insufficient deduction of user power caused by submitting multiple requests at one time
This commit is contained in:
@@ -131,7 +131,7 @@ func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
continue // 跳过不存在的
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
|
||||
if job.Status != types.JMTaskStatusSuccess && job.Power > 0 {
|
||||
remark := fmt.Sprintf("任务未成功,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
@@ -172,7 +172,7 @@ func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Status types.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
@@ -198,13 +198,13 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case model.JMTaskStatusInQueue:
|
||||
case types.JMTaskStatusInQueue:
|
||||
result["pendingTasks"] = stat.Count
|
||||
case model.JMTaskStatusSuccess:
|
||||
case types.JMTaskStatusSuccess:
|
||||
result["completedTasks"] = stat.Count
|
||||
case model.JMTaskStatusGenerating:
|
||||
case types.JMTaskStatusGenerating:
|
||||
result["processingTasks"] = stat.Count
|
||||
case model.JMTaskStatusFailed:
|
||||
case types.JMTaskStatusFailed:
|
||||
result["failedTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ type ChatHandler struct {
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
userLocks *types.UserLockManager
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
|
||||
@@ -80,6 +81,7 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
userLocks: types.NewUserLockManager(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +122,14 @@ func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户级并发锁,确保同一用户同时只有一个对话请求
|
||||
if !h.userLocks.TryLock(input.UserId) {
|
||||
pushMessage(c, ChatEventError, "您有一个对话请求正在进行中,请稍后再试或先停止当前生成!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer h.userLocks.Unlock(input.UserId)
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -50,37 +50,16 @@ func (h *JimengHandler) RegisterRoutes() {
|
||||
}
|
||||
}
|
||||
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
// 支持所有生图和生成视频类型
|
||||
type JimengTaskRequest struct {
|
||||
TaskType string `json:"task_type" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageInput string `json:"image_input"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
TemplateId string `json:"template_id"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
// CreateTask 统一任务创建接口
|
||||
func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
var req JimengTaskRequest
|
||||
var req types.JimengTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
if h.App.SysConfig.Moderation.Enable && req.Prompt != "" {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
@@ -103,166 +82,46 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
if req.Prompt == "" && len(req.ImageUrls) == 0 {
|
||||
resp.ERROR(c, "提示词和图片不能同时为空")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
// 获取算力消耗
|
||||
|
||||
var powerCost int
|
||||
var taskType model.JMTaskType
|
||||
var params map[string]any
|
||||
var reqKey string
|
||||
var modelName string
|
||||
// if user.Power < powerCost {
|
||||
// resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
// return
|
||||
// }
|
||||
|
||||
switch req.TaskType {
|
||||
case "text_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImage)
|
||||
taskType = model.JMTaskTypeImage
|
||||
reqKey = jimeng.ReqKeyTextToImage
|
||||
modelName = "即梦文生图"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
case "image_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
|
||||
taskType = model.JMTaskTypeVideo
|
||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||
modelName = "即梦图生图"
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
case "image_edit":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeVirtualHuman)
|
||||
taskType = model.JMTaskTypeVirtualHuman
|
||||
reqKey = jimeng.ReqKeyImageEdit
|
||||
modelName = "即梦图像编辑"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
params["image_urls"] = []string{req.ImageInput}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeActionTransfer)
|
||||
taskType = model.JMTaskTypeActionTransfer
|
||||
reqKey = jimeng.ReqKeyImageEffects
|
||||
modelName = "即梦图像特效"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input1": req.ImageInput,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
case "text_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
|
||||
taskType = model.JMTaskTypeVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
case "image_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeVideo)
|
||||
taskType = model.JMTaskTypeVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
default:
|
||||
resp.ERROR(c, "不支持的任务类型")
|
||||
return
|
||||
}
|
||||
// taskReq := &jimeng.CreateTaskRequest{
|
||||
// Type: taskType,
|
||||
// Prompt: req.Prompt,
|
||||
// Params: params,
|
||||
// ReqKey: reqKey,
|
||||
// Power: powerCost,
|
||||
// }
|
||||
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
// job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
// if err != nil {
|
||||
// logger.Errorf("create jimeng task failed: %v", err)
|
||||
// resp.ERROR(c, "创建任务失败")
|
||||
// return
|
||||
// }
|
||||
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: taskType,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: reqKey,
|
||||
Power: powerCost,
|
||||
}
|
||||
// h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||
// Type: types.PowerConsume,
|
||||
// Model: "jimeng",
|
||||
// Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||
// })
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
@@ -287,9 +146,9 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
|
||||
switch req.Filter {
|
||||
case "image":
|
||||
query = query.Where("type = ?", model.JMTaskTypeImage)
|
||||
query = query.Where("type = ?", types.JMTaskTypeImage)
|
||||
case "video":
|
||||
query = query.Where("type = ?", model.JMTaskTypeVideo)
|
||||
query = query.Where("type = ?", types.JMTaskTypeVideo)
|
||||
}
|
||||
|
||||
if len(req.Ids) > 0 {
|
||||
@@ -349,7 +208,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 正在运行中的任务不能删除
|
||||
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
|
||||
if job.Status == types.JMTaskStatusGenerating || job.Status == types.JMTaskStatusInQueue {
|
||||
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
|
||||
return
|
||||
}
|
||||
@@ -362,7 +221,7 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 失败任务删除后退回算力
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
if job.Status != types.JMTaskStatusFailed {
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
@@ -403,13 +262,13 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 只有失败的任务才能重试
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
if job.Status != types.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能重试")
|
||||
return
|
||||
}
|
||||
|
||||
// 重置任务状态
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), types.JMTaskStatusInQueue, ""); err != nil {
|
||||
logger.Errorf("reset job status failed: %v", err)
|
||||
resp.ERROR(c, "重置任务状态失败")
|
||||
return
|
||||
@@ -426,17 +285,17 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
}
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType types.JMTaskType) int {
|
||||
config := h.App.SysConfig.Jimeng
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeImage:
|
||||
case types.JMTaskTypeImage:
|
||||
return config.Power.Image
|
||||
case model.JMTaskTypeVideo:
|
||||
case types.JMTaskTypeVideo:
|
||||
return config.Power.Video
|
||||
case model.JMTaskTypeVirtualHuman:
|
||||
case types.JMTaskTypeVirtualHuman:
|
||||
return config.Power.VirtualHuman
|
||||
case model.JMTaskTypeActionTransfer:
|
||||
case types.JMTaskTypeActionTransfer:
|
||||
return config.Power.ActionTransfer
|
||||
default:
|
||||
return 10
|
||||
|
||||
Reference in New Issue
Block a user