mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-05 07:46:06 +08:00
重构异步任务更新方式,使用 Http 替代 websocket
This commit is contained in:
@@ -34,10 +34,8 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
userService *service.UserService
|
||||
wsService *service.WebsocketService
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
|
||||
@@ -45,11 +43,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +56,7 @@ func (s *Service) PushTask(task types.DallTask) {
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
// 将数据库中未提交的任务加载到队列
|
||||
var jobs []model.DallJob
|
||||
s.db.Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
@@ -84,16 +80,16 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("handle a new DALL-E task: %+v", task)
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
}
|
||||
go func() {
|
||||
_, err = s.Image(task, false)
|
||||
if err != nil {
|
||||
logger.Errorf("error with image task: %v", err)
|
||||
s.db.Model(&model.DallJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -212,10 +208,9 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.Id, int(task.UserId), res.Data[0].Url)
|
||||
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
@@ -225,26 +220,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChDall, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running DALL-E task status checking ...")
|
||||
@@ -254,7 +229,7 @@ func (s *Service) CheckTaskStatus() {
|
||||
s.db.Where("progress < ?", 100).Find(&jobs)
|
||||
for _, job := range jobs {
|
||||
// 超时的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
|
||||
if time.Since(job.CreatedAt) > time.Minute*10 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
@@ -301,7 +276,7 @@ func (s *Service) DownloadImages() {
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
|
||||
imgURL, err := s.downloadImage(v.Id, v.OrgURL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: %s, error: %v", imgURL, err)
|
||||
continue
|
||||
@@ -316,7 +291,7 @@ func (s *Service) DownloadImages() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
|
||||
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
if err != nil {
|
||||
@@ -328,6 +303,5 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
|
||||
if res.Error != nil {
|
||||
return "", err
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
|
||||
return imgURL, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user