geekai/api/service/sd/sd_service.go

73 lines
1.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package sd
import (
"chatplus/core/types"
"chatplus/service/mj"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"context"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
"time"
)
// Stable Diffusion (SD) drawing service.
//
// RunningJobKey is the Redis key that Run looks up before taking a task off
// the queue: while the key exists, Run sleeps and retries instead of dequeuing.
const RunningJobKey = "StableDiffusion_Running_Job"
// Service queues Stable Diffusion drawing tasks in Redis and consumes them in
// Run, handing each task's parameters to the SD client.
type Service struct {
	// Queue of pending tasks; PushTask appends to it and Run pops from it.
	taskQueue *store.RedisQueue
	// Redis client; Run uses it for the running-job key, and NewService passes
	// the same client to the task queue.
	redis *redis.Client
	// Database handle; Run uses it to flag a job row as started.
	db *gorm.DB
	// Client is the SD client; Run calls its Txt2Img method for every task.
	Client *Client
}
// NewService wires a Service together from its dependencies. The task queue is
// created on the given Redis client under the name
// "stable_diffusion_task_queue"; db and client are stored as passed in.
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
	queue := store.NewRedisQueue("stable_diffusion_task_queue", redisCli)

	svc := new(Service)
	svc.taskQueue = queue
	svc.redis = redisCli
	svc.db = db
	svc.Client = client
	return svc
}
// Run consumes Stable Diffusion tasks from the task queue one at a time and
// submits each to the SD client via Txt2Img. It loops forever and never
// returns, so callers presumably start it in its own goroutine.
//
// While RunningJobKey exists in Redis a job is treated as in flight and the
// loop polls every 3 seconds instead of dequeuing. After a task has been
// submitted successfully the job row is flagged as started and RunningJobKey
// is set with a 5-minute TTL, so the queue unblocks by itself if nothing
// removes the key earlier.
//
// A task whose submission fails is pushed back onto the queue with RetryCount
// incremented; once RetryCount has passed 5 the task is dropped.
func (s *Service) Run() {
	logger.Info("Starting StableDiffusion job consumer.")
	ctx := context.Background()
	for {
		// Jobs run serially: a readable key means one is still in progress.
		// Any error (missing key or a Redis failure) is treated as "no job
		// running", as before.
		if _, err := s.redis.Get(ctx, RunningJobKey).Result(); err == nil {
			time.Sleep(time.Second * 3)
			continue
		}

		var task types.SdTask
		if err := s.taskQueue.LPop(&task); err != nil {
			logger.Errorf("taking task with error: %v", err)
			// Back off briefly so a failing Redis does not turn this into a
			// busy loop that floods the log.
			time.Sleep(time.Second)
			continue
		}

		logger.Infof("Consuming Task: %+v", task)
		if err := s.Client.Txt2Img(task.Params); err != nil {
			logger.Error("绘画任务执行失败：", err)
			if task.RetryCount <= 5 {
				// Count this attempt BEFORE re-queuing; otherwise the queued
				// copy never advances and the task is retried forever.
				task.RetryCount++
				s.taskQueue.RPush(task)
			} else {
				logger.Errorf("dropping task after %v retries: %+v", task.RetryCount, task)
			}
			time.Sleep(time.Second * 3)
			continue
		}

		// Mark the job as started.
		// NOTE(review): this updates model.MidJourneyJob from the SD service;
		// confirm whether a dedicated SD job model should be used instead.
		res := s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
		if res.Error != nil {
			logger.Errorf("marking task %v as started: %v", task.Id, res.Error)
		}

		// Lock the execution channel until the job times out (5 minutes).
		// This must be the same key the loop checks above; the previous code
		// set mj.RunningJobKey, so the SD queue was never actually serialized.
		// NOTE(review): whatever clears the lock when a job finishes must
		// delete RunningJobKey as well — verify against the callback handler.
		if err := s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5).Err(); err != nil {
			logger.Errorf("locking the running-job channel: %v", err)
		}
	}
}

// Run no longer uses mj.RunningJobKey. This blank reference only keeps the
// file's existing "chatplus/service/mj" import in use so the file still
// compiles. TODO: remove it together with that import.
var _ = mj.RunningJobKey
// PushTask appends task to the tail of the Stable Diffusion task queue; Run
// pops tasks from the head of the same queue.
func (s *Service) PushTask(task types.SdTask) {
	// The message used to say "MidJourney", a copy-paste slip that made SD
	// and MidJourney submissions indistinguishable in the log.
	logger.Infof("add a new StableDiffusion Task: %+v", task)
	s.taskQueue.RPush(task)
}