mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
73 lines
1.7 KiB
Go
73 lines
1.7 KiB
Go
package sd
|
||
|
||
import (
|
||
"chatplus/core/types"
|
||
"chatplus/service/mj"
|
||
"chatplus/store"
|
||
"chatplus/store/model"
|
||
"chatplus/utils"
|
||
"context"
|
||
"github.com/go-redis/redis/v8"
|
||
"gorm.io/gorm"
|
||
"time"
|
||
)
|
||
|
||
// SD 绘画服务
|
||
|
||
const RunningJobKey = "StableDiffusion_Running_Job"
|
||
|
||
type Service struct {
|
||
taskQueue *store.RedisQueue
|
||
redis *redis.Client
|
||
db *gorm.DB
|
||
Client *Client
|
||
}
|
||
|
||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
|
||
return &Service{
|
||
redis: redisCli,
|
||
db: db,
|
||
Client: client,
|
||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||
}
|
||
}
|
||
|
||
func (s *Service) Run() {
|
||
logger.Info("Starting StableDiffusion job consumer.")
|
||
ctx := context.Background()
|
||
for {
|
||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||
if err == nil { // 队列串行执行
|
||
time.Sleep(time.Second * 3)
|
||
continue
|
||
}
|
||
var task types.SdTask
|
||
err = s.taskQueue.LPop(&task)
|
||
if err != nil {
|
||
logger.Errorf("taking task with error: %v", err)
|
||
continue
|
||
}
|
||
logger.Infof("Consuming Task: %+v", task)
|
||
err = s.Client.Txt2Img(task.Params)
|
||
if err != nil {
|
||
logger.Error("绘画任务执行失败:", err)
|
||
if task.RetryCount <= 5 {
|
||
s.taskQueue.RPush(task)
|
||
}
|
||
task.RetryCount += 1
|
||
time.Sleep(time.Second * 3)
|
||
continue
|
||
}
|
||
|
||
// 更新任务的执行状态
|
||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||
s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||
}
|
||
}
|
||
|
||
func (s *Service) PushTask(task types.SdTask) {
|
||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||
s.taskQueue.RPush(task)
|
||
}
|