mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: image preview for stable-diffusion task is ready
This commit is contained in:
		@@ -20,7 +20,7 @@ type ServicePool struct {
 | 
			
		||||
	Clients     *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
 | 
			
		||||
@@ -32,7 +32,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
			
		||||
 | 
			
		||||
		// create sd service
 | 
			
		||||
		name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
 | 
			
		||||
		service := NewService(name, config, taskQueue, notifyQueue, db, manager)
 | 
			
		||||
		service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
 | 
			
		||||
		// run sd service
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
 
 | 
			
		||||
@@ -24,9 +24,10 @@ type Service struct {
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	name          string // service name
 | 
			
		||||
	leveldb       *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
 | 
			
		||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
 | 
			
		||||
	config.ApiURL = strings.TrimRight(config.ApiURL, "/")
 | 
			
		||||
	return &Service{
 | 
			
		||||
		name:          name,
 | 
			
		||||
@@ -35,6 +36,7 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor
 | 
			
		||||
		taskQueue:     taskQueue,
 | 
			
		||||
		notifyQueue:   notifyQueue,
 | 
			
		||||
		db:            db,
 | 
			
		||||
		leveldb:       levelDB,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -167,15 +169,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
			}
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			// 从 leveldb 中删除预览图片数据
 | 
			
		||||
			_ = s.leveldb.Delete(task.Params.TaskId)
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			err, resp := s.checkTaskProgress()
 | 
			
		||||
			// 更新任务进度
 | 
			
		||||
			if err == nil && resp.Progress > 0 {
 | 
			
		||||
				logger.Debugf("Check task progress: %+v", resp.Progress)
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
			
		||||
				// 发送更新状态信号
 | 
			
		||||
				s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
				// 保存预览图片数据
 | 
			
		||||
				if resp.CurrentImage != "" {
 | 
			
		||||
					_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user