mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
feat: image preview for stable-diffusion task is ready
This commit is contained in:
@@ -20,7 +20,7 @@ type ServicePool struct {
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||
@@ -32,7 +32,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
|
||||
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
|
||||
service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
|
||||
@@ -24,9 +24,10 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
leveldb *store.LevelDB
|
||||
}
|
||||
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||
return &Service{
|
||||
name: name,
|
||||
@@ -35,6 +36,7 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
@@ -167,15 +169,20 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
}
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress()
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
logger.Debugf("Check task progress: %+v", resp.Progress)
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user