feat: image preview for stable-diffusion task is ready

This commit is contained in:
RockYang
2024-04-02 17:24:38 +08:00
parent da14309ef9
commit 1cff4b63cd
13 changed files with 171 additions and 13 deletions

View File

@@ -20,7 +20,7 @@ type ServicePool struct {
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
@@ -32,7 +32,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
// run sd service
go func() {
service.Run()

View File

@@ -24,9 +24,10 @@ type Service struct {
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
leveldb *store.LevelDB
}
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{
name: name,
@@ -35,6 +36,7 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
leveldb: levelDB,
uploadManager: manager,
}
}
@@ -167,15 +169,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
}
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(task.UserId)
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
err, resp := s.checkTaskProgress()
// 更新任务进度
if err == nil && resp.Progress > 0 {
logger.Debugf("Check task progress: %+v", resp.Progress)
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(task.UserId)
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
}
time.Sleep(time.Second)
}