mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 08:46:38 +08:00
306 lines
8.9 KiB
Go
306 lines
8.9 KiB
Go
package sd
|
|
|
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
// * Use of this source code is governed by a Apache-2.0 license
|
|
// * that can be found in the LICENSE file.
|
|
// * @Author yangjian102621@163.com
|
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
|
|
import (
|
|
"fmt"
|
|
"geekai/core/types"
|
|
logger2 "geekai/logger"
|
|
"geekai/service"
|
|
"geekai/service/oss"
|
|
"geekai/store"
|
|
"geekai/store/model"
|
|
"geekai/utils"
|
|
"github.com/go-redis/redis/v8"
|
|
"time"
|
|
|
|
"github.com/imroc/req/v3"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// logger is the package-level logger used by the Stable-Diffusion service.
var logger = logger2.GetLogger()
|
|
|
|
// Service is the Stable-Diffusion (SD) drawing service. It consumes SD tasks
// from a Redis-backed queue, sends them to an API exposing the /sdapi/v1
// endpoints, stores the generated image and queues status notifications that
// are forwarded to connected WebSocket clients.
type Service struct {
	httpClient    *req.Client
	taskQueue     *store.RedisQueue // pending tasks: pushed by PushTask, popped by Run
	notifyQueue   *store.RedisQueue // status messages consumed by CheckTaskNotify
	db            *gorm.DB
	uploadManager *oss.UploaderManager // uploads the generated Base64 images
	leveldb       *store.LevelDB       // caches in-progress preview images keyed by task id
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
}
|
|
|
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
|
return &Service{
|
|
httpClient: req.C(),
|
|
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
|
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
|
db: db,
|
|
leveldb: levelDB,
|
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
|
uploadManager: manager,
|
|
}
|
|
}
|
|
|
|
func (s *Service) Run() {
|
|
logger.Infof("Starting Stable-Diffusion job consumer")
|
|
go func() {
|
|
for {
|
|
var task types.SdTask
|
|
err := s.taskQueue.LPop(&task)
|
|
if err != nil {
|
|
logger.Errorf("taking task with error: %v", err)
|
|
continue
|
|
}
|
|
|
|
// translate prompt
|
|
if utils.HasChinese(task.Params.Prompt) {
|
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
|
if err == nil {
|
|
task.Params.Prompt = content
|
|
} else {
|
|
logger.Warnf("error with translate prompt: %v", err)
|
|
}
|
|
}
|
|
|
|
// translate negative prompt
|
|
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
|
if err == nil {
|
|
task.Params.NegPrompt = content
|
|
} else {
|
|
logger.Warnf("error with translate prompt: %v", err)
|
|
}
|
|
}
|
|
|
|
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
|
|
err = s.Txt2Img(task)
|
|
if err != nil {
|
|
logger.Error("绘画任务执行失败:", err.Error())
|
|
// update the task progress
|
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
|
"progress": service.FailTaskProgress,
|
|
"err_msg": err.Error(),
|
|
})
|
|
// 通知前端,任务失败
|
|
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
|
continue
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
type (
	// Txt2ImgReq is the JSON body posted to the /sdapi/v1/txt2img endpoint.
	// Optional fields tagged omitempty are left out of the request when they
	// hold their zero value.
	Txt2ImgReq struct {
		Prompt            string  `json:"prompt"`
		NegativePrompt    string  `json:"negative_prompt"`
		Seed              int64   `json:"seed,omitempty"`
		Steps             int     `json:"steps"`
		CfgScale          float32 `json:"cfg_scale"`
		Width             int     `json:"width"`
		Height            int     `json:"height"`
		SamplerName       string  `json:"sampler_name"`
		Scheduler         string  `json:"scheduler"`
		EnableHr          bool    `json:"enable_hr,omitempty"`
		HrScale           int     `json:"hr_scale,omitempty"`
		HrUpscaler        string  `json:"hr_upscaler,omitempty"`
		HrSecondPassSteps int     `json:"hr_second_pass_steps,omitempty"`
		DenoisingStrength float32 `json:"denoising_strength,omitempty"`
		ForceTaskId       string  `json:"force_task_id,omitempty"`
	}

	// Txt2ImgResp is the decoded response of the /sdapi/v1/txt2img endpoint.
	// Images holds the generated images as strings (uploaded via PutBase64 by
	// the caller); Info is itself a JSON document from which the caller reads
	// the "seed" key. The contents of "parameters" are ignored.
	Txt2ImgResp struct {
		Images     []string `json:"images"`
		Parameters struct{} `json:"parameters"`
		Info       string   `json:"info"`
	}

	// TaskProgressResp is the decoded response of the /sdapi/v1/progress
	// endpoint. Progress is multiplied by 100 by the caller, so it is
	// presumably a fraction in [0, 1] — confirm against the API.
	// CurrentImage, when non-empty, is cached as the preview image.
	TaskProgressResp struct {
		Progress     float64 `json:"progress"`
		EtaRelative  float64 `json:"eta_relative"`
		CurrentImage string  `json:"current_image"`
	}
)
|
|
|
|
// Txt2Img 文生图 API
|
|
func (s *Service) Txt2Img(task types.SdTask) error {
|
|
body := Txt2ImgReq{
|
|
Prompt: task.Params.Prompt,
|
|
NegativePrompt: task.Params.NegPrompt,
|
|
Steps: task.Params.Steps,
|
|
CfgScale: task.Params.CfgScale,
|
|
Width: task.Params.Width,
|
|
Height: task.Params.Height,
|
|
SamplerName: task.Params.Sampler,
|
|
Scheduler: task.Params.Scheduler,
|
|
ForceTaskId: task.Params.TaskId,
|
|
}
|
|
if task.Params.Seed > 0 {
|
|
body.Seed = task.Params.Seed
|
|
}
|
|
if task.Params.HdFix {
|
|
body.EnableHr = true
|
|
body.HrScale = task.Params.HdScale
|
|
body.HrUpscaler = task.Params.HdScaleAlg
|
|
body.HrSecondPassSteps = task.Params.HdSteps
|
|
body.DenoisingStrength = task.Params.HdRedrawRate
|
|
}
|
|
var res Txt2ImgResp
|
|
var errChan = make(chan error)
|
|
|
|
var apiKey model.ApiKey
|
|
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
|
|
if err != nil {
|
|
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
|
|
}
|
|
|
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
|
|
logger.Debugf("send image request to %s", apiURL)
|
|
// send a request to sd api endpoint
|
|
go func() {
|
|
response, err := s.httpClient.R().
|
|
SetHeader("Authorization", apiKey.Value).
|
|
SetBody(body).
|
|
SetSuccessResult(&res).
|
|
Post(apiURL)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
if response.IsErrorState() {
|
|
errChan <- fmt.Errorf("error http code status: %v", response.Status)
|
|
return
|
|
}
|
|
|
|
// update the last used time
|
|
apiKey.LastUsedAt = time.Now().Unix()
|
|
s.db.Updates(&apiKey)
|
|
|
|
// 保存 Base64 图片
|
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error with upload image: %v", err)
|
|
return
|
|
}
|
|
// 获取绘画真实的 seed
|
|
var info map[string]interface{}
|
|
err = utils.JsonDecode(res.Info, &info)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error with decode task response: %v", err)
|
|
return
|
|
}
|
|
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
|
|
errChan <- nil
|
|
}()
|
|
|
|
// waiting for task finish
|
|
for {
|
|
select {
|
|
case err := <-errChan:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// task finished
|
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
|
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
|
// 从 leveldb 中删除预览图片数据
|
|
_ = s.leveldb.Delete(task.Params.TaskId)
|
|
return nil
|
|
default:
|
|
err, resp := s.checkTaskProgress(apiKey)
|
|
// 更新任务进度
|
|
if err == nil && resp.Progress > 0 {
|
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
|
// 发送更新状态信号
|
|
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
|
// 保存预览图片数据
|
|
if resp.CurrentImage != "" {
|
|
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
|
}
|
|
}
|
|
time.Sleep(time.Second)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// 执行任务
|
|
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
|
var res TaskProgressResp
|
|
response, err := s.httpClient.R().
|
|
SetHeader("Authorization", apiKey.Value).
|
|
SetSuccessResult(&res).
|
|
Get(apiURL)
|
|
if err != nil {
|
|
return err, nil
|
|
}
|
|
if response.IsErrorState() {
|
|
return fmt.Errorf("error http code status: %v", response.Status), nil
|
|
}
|
|
|
|
return nil, &res
|
|
}
|
|
|
|
func (s *Service) PushTask(task types.SdTask) {
|
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
|
s.taskQueue.RPush(task)
|
|
}
|
|
|
|
func (s *Service) CheckTaskNotify() {
|
|
go func() {
|
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
|
for {
|
|
var message service.NotifyMessage
|
|
err := s.notifyQueue.LPop(&message)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
client := s.Clients.Get(uint(message.UserId))
|
|
if client == nil {
|
|
continue
|
|
}
|
|
err = client.Send([]byte(message.Message))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
|
func (s *Service) CheckTaskStatus() {
|
|
go func() {
|
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
|
for {
|
|
var jobs []model.SdJob
|
|
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
|
if res.Error != nil {
|
|
time.Sleep(5 * time.Second)
|
|
continue
|
|
}
|
|
|
|
for _, job := range jobs {
|
|
// 5 分钟还没完成的任务标记为失败
|
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
|
job.Progress = service.FailTaskProgress
|
|
job.ErrMsg = "任务超时"
|
|
s.db.Updates(&job)
|
|
}
|
|
}
|
|
time.Sleep(time.Second * 5)
|
|
}
|
|
}()
|
|
}
|