mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
opt: enable use cdn url for mj-plus
This commit is contained in:
@@ -86,10 +86,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||
errMsg := err.Error() + res.Description
|
||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
||||
logger.Error("绘画任务执行失败:", errMsg)
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||
s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": errMsg,
|
||||
})
|
||||
@@ -105,10 +105,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||
// 更新任务 ID/频道
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
|
||||
s.db.Debug().Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": res.Result,
|
||||
"channel_id": s.Name,
|
||||
})
|
||||
@@ -152,26 +152,55 @@ type CBReq struct {
|
||||
} `json:"properties"`
|
||||
}
|
||||
|
||||
func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
|
||||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
task, err := s.Client.QueryTask(job.TaskId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
|
||||
job.Prompt = data.Properties.FinalPrompt
|
||||
if data.ImageUrl != "" {
|
||||
job.OrgURL = data.ImageUrl
|
||||
// 任务执行失败了
|
||||
if task.FailReason != "" {
|
||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||
}
|
||||
|
||||
if len(task.Buttons) > 0 {
|
||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := job.Progress
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
job.Prompt = task.PromptEn
|
||||
if task.ImageUrl != "" {
|
||||
if s.Client.Config.CdnURL != "" {
|
||||
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
|
||||
} else {
|
||||
job.OrgURL = task.ImageUrl
|
||||
}
|
||||
}
|
||||
job.UseProxy = true
|
||||
job.MessageId = data.Id
|
||||
logger.Debugf("JOB: %+v", job)
|
||||
res := s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with update job: %v", res.Error)
|
||||
job.MessageId = task.Id
|
||||
tx := s.db.Updates(&job)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||
}
|
||||
|
||||
if data.Status == "SUCCESS" {
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetImageHash(action string) string {
|
||||
split := strings.Split(action, "::")
|
||||
if len(split) > 5 {
|
||||
return split[4]
|
||||
}
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user