add user lock for chat api, Prevent insufficient deduction of user power caused by submitting multiple requests at one time

This commit is contained in:
GeekMaster
2025-09-12 15:05:14 +08:00
parent 65fb58585c
commit c5badb3e13
43 changed files with 309 additions and 429 deletions

View File

@@ -9,6 +9,7 @@ import (
"gorm.io/gorm"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service/oss"
"geekai/store"
@@ -95,7 +96,7 @@ func (s *Service) processNextTask() {
if err := s.ProcessTask(jobId); err != nil {
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, err.Error())
} else {
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
}
@@ -120,7 +121,7 @@ func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.Jimeng
ReqKey: req.ReqKey,
Prompt: req.Prompt,
Params: string(paramsJson),
Status: model.JMTaskStatusInQueue,
Status: types.JMTaskStatusInQueue,
Power: req.Power,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
@@ -148,7 +149,7 @@ func (s *Service) ProcessTask(jobId uint) error {
}
// 更新任务状态为处理中
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
if err := s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, ""); err != nil {
return fmt.Errorf("update job status failed: %w", err)
}
@@ -199,13 +200,13 @@ func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, er
// 根据任务类型设置特定参数
switch job.Type {
case model.JMTaskTypeImage:
case types.JMTaskTypeImage:
s.setTextToImageParams(req, params)
case model.JMTaskTypeVideo:
case types.JMTaskTypeVideo:
s.setImageToImageParams(req, params)
case model.JMTaskTypeVirtualHuman:
case types.JMTaskTypeVirtualHuman:
s.setImageEditParams(req, params)
case model.JMTaskTypeActionTransfer:
case types.JMTaskTypeActionTransfer:
s.setImageEffectsParams(req, params)
default:
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
@@ -353,7 +354,7 @@ func (s *Service) pollTaskStatus() {
for {
var jobs []model.JimengJob
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
s.db.Where("status IN (?)", []types.JMTaskStatus{types.JMTaskStatusGenerating, types.JMTaskStatusInQueue}).Find(&jobs)
if len(jobs) == 0 {
logger.Debugf("no jimeng task to poll, sleep 10s")
time.Sleep(10 * time.Second)
@@ -389,7 +390,7 @@ func (s *Service) pollTaskStatus() {
}
switch resp.Data.Status {
case model.JMTaskStatusDone:
case types.JMTaskStatusDone:
// 判断任务是否成功
if resp.Message != "Success" {
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
@@ -398,7 +399,7 @@ func (s *Service) pollTaskStatus() {
// 任务完成,更新结果
updates := map[string]any{
"status": model.JMTaskStatusSuccess,
"status": types.JMTaskStatusSuccess,
"updated_at": time.Now(),
}
@@ -421,15 +422,15 @@ func (s *Service) pollTaskStatus() {
}
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
case types.JMTaskStatusInQueue, types.JMTaskStatusGenerating:
// 任务处理中
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, "")
case model.JMTaskStatusNotFound:
case types.JMTaskStatusNotFound:
// 任务未找到
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
case types.JMTaskStatusExpired:
continue
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
@@ -444,7 +445,7 @@ func (s *Service) pollTaskStatus() {
}
// UpdateJobStatus 更新任务状态
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
func (s *Service) UpdateJobStatus(jobId uint, status types.JMTaskStatus, errMsg string) error {
updates := map[string]any{
"status": status,
"updated_at": time.Now(),
@@ -458,7 +459,7 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg
// handleTaskError 处理任务错误
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
return s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, errMsg)
}
// PushTaskToQueue 推送任务到队列(用于手动重试)
@@ -469,8 +470,8 @@ func (s *Service) PushTaskToQueue(jobId uint) error {
// GetTaskStats 获取任务统计信息
func (s *Service) GetTaskStats() (map[string]any, error) {
type StatResult struct {
Status string `json:"status"`
Count int64 `json:"count"`
Status types.JMTaskStatus `json:"status"`
Count int64 `json:"count"`
}
var stats []StatResult
@@ -492,7 +493,7 @@ func (s *Service) GetTaskStats() (map[string]any, error) {
for _, stat := range stats {
result["total"] = result["total"].(int64) + stat.Count
result[stat.Status] = stat.Count
result[string(stat.Status)] = stat.Count
}
return result, nil

View File

@@ -1,15 +1,7 @@
package jimeng
import "geekai/store/model"
// ReqKey 常量定义
const (
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
import (
"geekai/core/types"
)
// SubmitTaskRequest 提交任务请求
@@ -73,7 +65,7 @@ type QueryTaskResponse struct {
ImageUrls []string `json:"image_urls"`
VideoUrl string `json:"video_url"`
RespData string `json:"resp_data"`
Status model.JMTaskStatus `json:"status"`
Status types.JMTaskStatus `json:"status"`
LlmResult string `json:"llm_result"`
PeResult string `json:"pe_result"`
PredictTagsResult string `json:"predict_tags_result"`
@@ -85,61 +77,10 @@ type QueryTaskResponse struct {
// CreateTaskRequest 创建任务请求
type CreateTaskRequest struct {
Type model.JMTaskType `json:"type"`
Type types.JMTaskType `json:"type"`
Prompt string `json:"prompt"`
Params map[string]any `json:"params"`
ReqKey string `json:"req_key"`
ImageUrls []string `json:"image_urls,omitempty"`
Power int `json:"power,omitempty"`
}
// LogoInfo 水印信息
type LogoInfo struct {
AddLogo bool `json:"add_logo"`
Position int `json:"position"`
Language int `json:"language"`
Opacity float64 `json:"opacity"`
LogoTextContent string `json:"logo_text_content"`
}
// ReqJsonConfig 查询配置
type ReqJsonConfig struct {
ReturnUrl bool `json:"return_url"`
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
}
// ImageEffectTemplate 图像特效模板
const (
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
TemplateIdMyWorld = "my_world" // 像素世界风
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
TemplateIdGlassBall = "glass_ball" // 玻璃球
)
// AspectRatio 视频宽高比
const (
AspectRatio16_9 = "16:9" // 1280×720
AspectRatio9_16 = "9:16" // 720×1280
AspectRatio1_1 = "1:1" // 960×960
AspectRatio4_3 = "4:3" // 960×720
AspectRatio3_4 = "3:4" // 720×960
AspectRatio21_9 = "21:9" // 1680×720
AspectRatio9_21 = "9:21" // 720×1680
)
// GenMode 生成模式
const (
GenModeCreative = "creative" // 提示词模式
GenModeReference = "reference" // 全参考模式
GenModeReferenceChar = "reference_char" // 人物参考模式
)