mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
refactor: refactor stable diffusion service, add service pool support
This commit is contained in:
52
api/service/sd/pool.go
Normal file
52
api/service/sd/pool.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ServicePool struct {
|
||||
services []*Service
|
||||
taskQueue *store.RedisQueue
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||
// create mj client and service
|
||||
for k, config := range appConfig.SdConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf("StableDifffusion Service-%d", k)
|
||||
service := NewService(name, 4, 600, &config, queue, db, manager)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: queue,
|
||||
services: services,
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (p *ServicePool) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
p.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// HasAvailableService check if it has available mj service in pool
|
||||
func (p *ServicePool) HasAvailableService() bool {
|
||||
return len(p.services) > 0
|
||||
}
|
||||
@@ -5,84 +5,96 @@ import (
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
const RunningJobKey = "StableDiffusion_Running_Job"
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
config *types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池
|
||||
httpClient *req.Client
|
||||
config *types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
handledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
}
|
||||
|
||||
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
func NewService(name string, maxTaskNum int32, timeout int64, config *types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
return &Service{
|
||||
config: &config.SdConfig,
|
||||
httpClient: req.C(),
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
uploadManager: manager,
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||||
name: name,
|
||||
config: config,
|
||||
httpClient: req.C(),
|
||||
taskQueue: queue,
|
||||
db: db,
|
||||
uploadManager: manager,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
taskStartTimes: make(map[int]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting StableDiffusion job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
var task types.SdTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
||||
err = s.Txt2Img(task)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
// update the task progress
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
// release task num
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask 推送任务到队列
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Infof("add a new Stable Diffusion Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
// remove the expired tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
// delete task from database
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
@@ -237,9 +249,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||
}
|
||||
|
||||
func (s *Service) callback(data CBReq) {
|
||||
// 释放任务锁
|
||||
s.redis.Del(context.Background(), RunningJobKey)
|
||||
client := s.Clients.Get(data.SessionId)
|
||||
// release task num
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
if data.Success { // 任务成功
|
||||
var job model.SdJob
|
||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
||||
@@ -259,13 +270,15 @@ func (s *Service) callback(data CBReq) {
|
||||
|
||||
params.Seed = data.Seed
|
||||
if data.ImageName != "" { // 下载图片
|
||||
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
||||
if data.Progress == 100 {
|
||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imageURL
|
||||
}
|
||||
job.ImgURL = imageURL
|
||||
}
|
||||
|
||||
job.Params = utils.JsonEncode(params)
|
||||
@@ -275,38 +288,16 @@ func (s *Service) callback(data CBReq) {
|
||||
return
|
||||
}
|
||||
|
||||
var jobVo vo.SdJob
|
||||
err = utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
logger.Error("error with copy object: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Progress < 100 && data.ImageData != "" {
|
||||
jobVo.ImgURL = data.ImageData
|
||||
}
|
||||
|
||||
logger.Infof("绘图进度:%d", data.Progress)
|
||||
logger.Debugf("绘图进度:%d", data.Progress)
|
||||
|
||||
// 扣减绘图次数
|
||||
if data.Progress == 100 {
|
||||
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", jobVo.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
// 推送任务到前端
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
} else { // 任务失败
|
||||
logger.Error("任务执行失败:", data.Message)
|
||||
// 删除任务
|
||||
s.db.Delete(&model.SdJob{Id: uint(data.JobId)})
|
||||
// 推送消息到前端
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, vo.SdJob{
|
||||
Id: uint(data.JobId),
|
||||
Progress: -1,
|
||||
TaskId: data.TaskId,
|
||||
})
|
||||
}
|
||||
// update the task progress
|
||||
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user