feat: add implements for stable diffusion service

This commit is contained in:
RockYang
2023-09-26 18:16:51 +08:00
parent c1143d7a6d
commit d51a724ade
8 changed files with 569 additions and 62 deletions

View File

@@ -21,41 +21,6 @@ var logger = logger2.GetLogger()
const MjRunningJobKey = "MidJourney_Running_Job"
type TaskType string
func (t TaskType) String() string {
return string(t)
}
const (
Image = TaskType("image")
Upscale = TaskType("upscale")
Variation = TaskType("variation")
)
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat")
TaskSrcImg = TaskSrc("img")
)
type MjTask struct {
Id int `json:"id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int32 `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"`
}
type MjService struct {
config types.ChatPlusExtConfig
client *req.Client
@@ -78,11 +43,11 @@ func (s *MjService) Run() {
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
if err == nil {
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task MjTask
var task types.MjTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
@@ -90,17 +55,17 @@ func (s *MjService) Run() {
}
logger.Infof("Consuming Task: %+v", task)
switch task.Type {
case Image:
case types.TaskImage:
err = s.image(task.Prompt)
break
case Upscale:
case types.TaskUpscale:
err = s.upscale(MjUpscaleReq{
Index: task.Index,
MessageId: task.MessageId,
MessageHash: task.MessageHash,
})
break
case Variation:
case types.TaskVariation:
err = s.variation(MjVariationReq{
Index: task.Index,
MessageId: task.MessageId,
@@ -124,7 +89,7 @@ func (s *MjService) Run() {
}
}
func (s *MjService) PushTask(task MjTask) {
func (s *MjService) PushTask(task types.MjTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}

95
api/service/sd_service.go Normal file
View File

@@ -0,0 +1,95 @@
package service
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"time"
)
// SD 绘画服务
const SdRunningJobKey = "StableDiffusion_Running_Job"
type SdService struct {
config types.ChatPlusExtConfig
client *req.Client
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
}
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
return &SdService{
config: appConfig.ExtConfig,
redis: client,
db: db,
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
client: req.C().SetTimeout(30 * time.Second)}
}
func (s *SdService) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
err = s.txt2img(task.Params)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
continue
}
// 更新任务的执行状态
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
}
}
func (s *SdService) PushTask(task types.SdTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}
func (s *SdService) txt2img(params types.SdParams) error {
logger.Infof("SD 绘画参数:%+v", params)
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("Authorization", s.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(params).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}