feat: stable-diffusion refactored, replace websocket api with sdapi

This commit is contained in:
RockYang
2024-03-26 18:23:08 +08:00
parent 870706c4ff
commit b60a639312
11 changed files with 191 additions and 335 deletions

View File

@@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
for _, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
// run sd service
go func() {
service.Run()
@@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) {
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
@@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)