diff --git a/api/service/sd/service.go b/api/service/sd/service.go index cef9fe60..1741028f 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -136,6 +136,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { taskInfo.TaskId = params.TaskId taskInfo.Data = data taskInfo.JobId = task.Id + taskInfo.UserId = uint(task.UserId) go func() { s.runTask(taskInfo, s.httpClient) }() @@ -158,7 +159,7 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { Duration float64 `json:"duration"` AverageDuration float64 `json:"average_duration"` } - var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} + var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict") if err != nil { cbReq.Message = "error with send request: " + err.Error() @@ -231,7 +232,7 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { TextInfo interface{} `json:"textinfo"` } response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress") - var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} + var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} if err != nil { // TODO: 这里可以考虑设置失败重试次数 logger.Error(err) return @@ -292,15 +293,11 @@ func (s *Service) callback(data CBReq) { } logger.Debugf("绘图进度:%d", data.Progress) - - // 扣减绘图次数 - if data.Progress == 100 { - s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) - } - } else { // 任务失败 logger.Error("任务执行失败:", data.Message) // update the task progress s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1) + // restore img_calls + s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) } } diff --git a/api/service/sd/types.go b/api/service/sd/types.go index b3a949c6..56ebb5bd 100644 --- a/api/service/sd/types.go +++ b/api/service/sd/types.go @@ -5,6 +5,7 @@ import logger2 "chatplus/logger" var logger = logger2.GetLogger() type TaskInfo struct { + UserId uint `json:"user_id"` SessionId string `json:"session_id"` JobId int `json:"job_id"` TaskId string `json:"task_id"` @@ -15,6 +16,7 @@ type TaskInfo struct { } type CBReq struct { + UserId uint SessionId string JobId int TaskId string