feat: stable-diffusion refactored, replace websocket api with sdapi

This commit is contained in:
RockYang 2024-03-26 18:23:08 +08:00
parent 870706c4ff
commit b60a639312
11 changed files with 191 additions and 335 deletions

View File

@ -56,10 +56,10 @@ type MidJourneyConfig struct {
}
type StableDiffusionConfig struct {
Enabled bool
ApiURL string
ApiKey string
Txt2ImgJsonPath string
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MidJourneyPlusConfig struct {

View File

@ -175,6 +175,12 @@ func main() {
// Stable Diffusion 机器人
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool) {
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
}
}),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),

View File

@ -1,80 +0,0 @@
{
"data": [
"task(cxvkpawy8onnfti)",
"a cute girl",
"",
[],
20,
"DPM++ 2M Karras",
1,
1,
7,
512,
512,
false,
0.7,
2,
"Latent",
0,
0,
0,
"Use same checkpoint",
"Use same sampler",
"",
"",
[],
"None",
false,
"",
0.8,
-1,
false,
-1,
0,
0,
0,
null,
null,
null,
null,
false,
false,
"positive",
"comma",
0,
false,
false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
null,
null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 446,
"session_hash": "nk5noh1rz1o"
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"net/url"
"path/filepath"
@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
}
func (s AliYunOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@ -3,13 +3,13 @@ package oss
import (
"chatplus/core/types"
"chatplus/utils"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type LocalStorage struct {
@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
err = os.WriteFile(filePath, imageData, 0644)
if err != nil {
return "", fmt.Errorf("error writing to file:%v", err)
}
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) Delete(fileURL string) error {
if _, err := os.Stat(fileURL); err == nil {
return os.Remove(fileURL)

View File

@ -4,6 +4,7 @@ import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"net/url"
"path/filepath"
@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
}, nil
}
func (s MiniOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
objectKey,
strings.NewReader(string(imageData)),
int64(len(imageData)),
minio.PutObjectOptions{ContentType: "image/png"})
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
}
func (s MiniOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@ -5,6 +5,7 @@ import (
"chatplus/core/types"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"net/url"
"path/filepath"
@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
if err != nil {
return "", err
}
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {

View File

@ -17,5 +17,6 @@ type File struct {
type Uploader interface {
PutFile(ctx *gin.Context, name string) (File, error)
PutImg(imageURL string, useProxy bool) (string, error)
PutBase64(imageData string) (string, error)
Delete(fileURL string) error
}

View File

@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
for _, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
// run sd service
go func() {
service.Run()
@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) {
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)

View File

@ -6,59 +6,40 @@ import (
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"encoding/json"
"fmt"
"io"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"strings"
"time"
)
// SD 绘画服务
type Service struct {
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
}
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{
name: name,
config: config,
httpClient: req.C(),
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time),
name: name,
config: config,
httpClient: req.C(),
taskQueue: taskQueue,
notifyQueue: notifyQueue,
db: db,
uploadManager: manager,
}
}
func (s *Service) Run() {
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
@ -74,239 +55,135 @@ func (s *Service) Run() {
"progress": -1,
"err_msg": err.Error(),
})
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
// 通知前端,任务失败
s.notifyQueue.RPush(task.UserId)
continue
}
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
// Txt2ImgReq 文生图请求实体
type Txt2ImgReq struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Seed int64 `json:"seed"`
Steps int `json:"steps"`
CfgScale float32 `json:"cfg_scale"`
Width int `json:"width"`
Height int `json:"height"`
SamplerName string `json:"sampler_name"`
EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"`
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
ForceTaskId string `json:"force_task_id,omitempty"`
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
// Txt2ImgResp 文生图响应实体
type Txt2ImgResp struct {
Images []string `json:"images"`
Parameters struct {
} `json:"parameters"`
Info string `json:"info"`
}
// TaskProgressResp 任务进度响应实体
type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
}
// Txt2Img 文生图 API
func (s *Service) Txt2Img(task types.SdTask) error {
var taskInfo TaskInfo
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
if err != nil {
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
body := Txt2ImgReq{
Prompt: task.Params.Prompt,
NegativePrompt: task.Params.NegativePrompt,
Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
}
err = json.Unmarshal(bytes, &taskInfo)
if err != nil {
return fmt.Errorf("error with decode json params: %s", err.Error())
if task.Params.Seed > 0 {
body.Seed = task.Params.Seed
}
data := taskInfo.Data
params := task.Params
data[ParamKeys["task_id"]] = params.TaskId
data[ParamKeys["prompt"]] = params.Prompt
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
data[ParamKeys["steps"]] = params.Steps
data[ParamKeys["sampler"]] = params.Sampler
// @fix bug: 有些 stable diffusion 没有面部修复功能
//data[ParamKeys["face_fix"]] = params.FaceFix
data[ParamKeys["cfg_scale"]] = params.CfgScale
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSteps
taskInfo.SessionId = task.SessionId
taskInfo.TaskId = params.TaskId
taskInfo.Data = data
taskInfo.JobId = task.Id
taskInfo.UserId = uint(task.UserId)
if task.Params.HdFix {
body.EnableHr = true
body.HrScale = task.Params.HdScale
body.HrUpscaler = task.Params.HdScaleAlg
body.HrSecondPassSteps = task.Params.HdSteps
body.DenoisingStrength = task.Params.HdRedrawRate
}
var res Txt2ImgResp
var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
logger.Debugf("send image request to %s", apiURL)
go func() {
s.runTask(taskInfo, s.httpClient)
}()
return nil
}
// 执行任务
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
if err != nil {
cbReq.Message = "error with send request: " + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- err
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error http code status: %v", response.Status)
return
}
var images []struct {
Name string `json:"name"`
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
// 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil {
cbReq.Message = "error with decode image:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with upload image: %v", err)
return
}
var info map[string]any
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
// 获取绘画真实的 seed
var info map[string]interface{}
err = utils.JsonDecode(res.Info, &info)
if err != nil {
logger.Error(res.Data)
cbReq.Message = "error with decode image url:" + err.Error()
cbReq.Success = false
result <- cbReq
errChan <- fmt.Errorf("error with decode task response: %v", err)
return
}
// 获取真实的 seed 值
cbReq.ImageName = images[0].Name
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
cbReq.Seed = seed
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
errChan <- nil
}()
for {
select {
case value := <-result:
s.callback(value)
return
case err := <-errChan: // 任务完成
if err != nil {
return err
}
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(task.UserId)
return nil
default:
var progressReq = map[string]any{
"id_task": taskInfo.TaskId,
"id_live_preview": 1,
err, resp := s.checkTaskProgress()
// 更新任务进度
if err == nil && resp.Progress > 0 {
logger.Debugf("Check task progress: %+v", resp.Progress)
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(task.UserId)
}
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
s.callback(cbReq)
time.Sleep(time.Second)
}
}
}
func (s *Service) callback(data CBReq) {
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
if data.Success { // 任务成功
var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
// 更新任务进度
job.Progress = data.Progress
// 更新任务 seed
var params types.SdTaskParams
err := utils.JsonDecode(job.Params, &params)
if err != nil {
logger.Error("任务解析失败:", err)
return
}
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
if data.Progress == 100 {
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
}
job.Params = utils.JsonEncode(params)
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
logger.Debugf("绘图进度:%d", data.Progress)
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// update the task progress
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": data.Message,
})
// 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
var res TaskProgressResp
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return err, nil
}
if response.IsErrorState() {
return fmt.Errorf("error http code status: %v", response.Status), nil
}
// 发送更新状态信号
s.notifyQueue.RPush(data.UserId)
return nil, &res
}

View File

@ -1,21 +1,5 @@
package main
import (
"chatplus/utils"
"fmt"
)
type Person struct {
Name string
Age int
}
type Student struct {
Person
School string
}
func main() {
fmt.Println(utils.RandString(64))
}