mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
96 lines
2.3 KiB
Go
96 lines
2.3 KiB
Go
package service
|
||
|
||
import (
|
||
"chatplus/core/types"
|
||
"chatplus/store"
|
||
"chatplus/store/model"
|
||
"chatplus/utils"
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/go-redis/redis/v8"
|
||
"github.com/imroc/req/v3"
|
||
"gorm.io/gorm"
|
||
"time"
|
||
)
|
||
|
||
// Stable Diffusion (SD) drawing service
|
||
|
||
// SdRunningJobKey is the Redis key that SdService.Run polls before taking the
// next task: while the key exists, the consumer waits instead of popping.
const SdRunningJobKey = "StableDiffusion_Running_Job"
|
||
|
||
type SdService struct {
|
||
config types.ChatPlusExtConfig
|
||
client *req.Client
|
||
taskQueue *store.RedisQueue
|
||
redis *redis.Client
|
||
db *gorm.DB
|
||
}
|
||
|
||
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
|
||
return &SdService{
|
||
config: appConfig.ExtConfig,
|
||
redis: client,
|
||
db: db,
|
||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
|
||
client: req.C().SetTimeout(30 * time.Second)}
|
||
}
|
||
|
||
func (s *SdService) Run() {
|
||
logger.Info("Starting StableDiffusion job consumer.")
|
||
ctx := context.Background()
|
||
for {
|
||
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
|
||
if err == nil { // 队列串行执行
|
||
time.Sleep(time.Second * 3)
|
||
continue
|
||
}
|
||
var task types.SdTask
|
||
err = s.taskQueue.LPop(&task)
|
||
if err != nil {
|
||
logger.Errorf("taking task with error: %v", err)
|
||
continue
|
||
}
|
||
logger.Infof("Consuming Task: %+v", task)
|
||
err = s.txt2img(task.Params)
|
||
if err != nil {
|
||
logger.Error("绘画任务执行失败:", err)
|
||
if task.RetryCount <= 5 {
|
||
s.taskQueue.RPush(task)
|
||
}
|
||
task.RetryCount += 1
|
||
time.Sleep(time.Second * 3)
|
||
continue
|
||
}
|
||
|
||
// 更新任务的执行状态
|
||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||
}
|
||
}
|
||
|
||
func (s *SdService) PushTask(task types.SdTask) {
|
||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||
s.taskQueue.RPush(task)
|
||
}
|
||
|
||
func (s *SdService) txt2img(params types.SdParams) error {
|
||
logger.Infof("SD 绘画参数:%+v", params)
|
||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
||
var res types.BizVo
|
||
r, err := s.client.R().
|
||
SetHeader("Authorization", s.config.Token).
|
||
SetHeader("Content-Type", "application/json").
|
||
SetBody(params).
|
||
SetSuccessResult(&res).Post(url)
|
||
if err != nil || r.IsErrorState() {
|
||
return fmt.Errorf("%v%v", r.String(), err)
|
||
}
|
||
|
||
if res.Code != types.Success {
|
||
return errors.New(res.Message)
|
||
}
|
||
|
||
return nil
|
||
}
|