mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-05-10 19:54:25 +08:00
add user lock for chat api, Prevent insufficient deduction of user power caused by submitting multiple requests at one time
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
@@ -95,7 +96,7 @@ func (s *Service) processNextTask() {
|
||||
|
||||
if err := s.ProcessTask(jobId); err != nil {
|
||||
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||
s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, err.Error())
|
||||
} else {
|
||||
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
@@ -120,7 +121,7 @@ func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.Jimeng
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
Params: string(paramsJson),
|
||||
Status: model.JMTaskStatusInQueue,
|
||||
Status: types.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -148,7 +149,7 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
|
||||
if err := s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -199,13 +200,13 @@ func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, er
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case model.JMTaskTypeImage:
|
||||
case types.JMTaskTypeImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case model.JMTaskTypeVideo:
|
||||
case types.JMTaskTypeVideo:
|
||||
s.setImageToImageParams(req, params)
|
||||
case model.JMTaskTypeVirtualHuman:
|
||||
case types.JMTaskTypeVirtualHuman:
|
||||
s.setImageEditParams(req, params)
|
||||
case model.JMTaskTypeActionTransfer:
|
||||
case types.JMTaskTypeActionTransfer:
|
||||
s.setImageEffectsParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
@@ -353,7 +354,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
s.db.Where("status IN (?)", []types.JMTaskStatus{types.JMTaskStatusGenerating, types.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
@@ -389,7 +390,7 @@ func (s *Service) pollTaskStatus() {
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
case types.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
@@ -398,7 +399,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"status": types.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
@@ -421,15 +422,15 @@ func (s *Service) pollTaskStatus() {
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
case types.JMTaskStatusInQueue, types.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
s.UpdateJobStatus(job.Id, types.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
case types.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
case types.JMTaskStatusExpired:
|
||||
continue
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
@@ -444,7 +445,7 @@ func (s *Service) pollTaskStatus() {
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status types.JMTaskStatus, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
@@ -458,7 +459,7 @@ func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
return s.UpdateJobStatus(jobId, types.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
@@ -469,8 +470,8 @@ func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
Status types.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
@@ -492,7 +493,7 @@ func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
result[string(stat.Status)] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
Reference in New Issue
Block a user