mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: stable-diffusion refactored, replace websocket api with sdapi
This commit is contained in:
parent
49b5906bc7
commit
342b76f666
@ -57,9 +57,9 @@ type MidJourneyConfig struct {
|
|||||||
|
|
||||||
type StableDiffusionConfig struct {
|
type StableDiffusionConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
|
Model string // 模型名称
|
||||||
ApiURL string
|
ApiURL string
|
||||||
ApiKey string
|
ApiKey string
|
||||||
Txt2ImgJsonPath string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyPlusConfig struct {
|
type MidJourneyPlusConfig struct {
|
||||||
|
@ -175,6 +175,12 @@ func main() {
|
|||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewServicePool),
|
fx.Provide(sd.NewServicePool),
|
||||||
|
fx.Invoke(func(pool *sd.ServicePool) {
|
||||||
|
if pool.HasAvailableService() {
|
||||||
|
pool.CheckTaskNotify()
|
||||||
|
pool.CheckTaskStatus()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(payment.NewAlipayService),
|
fx.Provide(payment.NewAlipayService),
|
||||||
fx.Provide(payment.NewHuPiPay),
|
fx.Provide(payment.NewHuPiPay),
|
||||||
|
@ -1,80 +0,0 @@
|
|||||||
{
|
|
||||||
"data": [
|
|
||||||
"task(cxvkpawy8onnfti)",
|
|
||||||
"a cute girl",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
20,
|
|
||||||
"DPM++ 2M Karras",
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
7,
|
|
||||||
512,
|
|
||||||
512,
|
|
||||||
false,
|
|
||||||
0.7,
|
|
||||||
2,
|
|
||||||
"Latent",
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
"Use same checkpoint",
|
|
||||||
"Use same sampler",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"None",
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
0.8,
|
|
||||||
-1,
|
|
||||||
false,
|
|
||||||
-1,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"positive",
|
|
||||||
"comma",
|
|
||||||
0,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"",
|
|
||||||
"Seed",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
"Nothing",
|
|
||||||
"",
|
|
||||||
[],
|
|
||||||
true,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
50,
|
|
||||||
[],
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
""
|
|
||||||
],
|
|
||||||
"event_data": null,
|
|
||||||
"fn_index": 446,
|
|
||||||
"session_hash": "nk5noh1rz1o"
|
|
||||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s AliYunOss) Delete(fileURL string) error {
|
func (s AliYunOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
@ -3,13 +3,13 @@ package oss
|
|||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LocalStorage struct {
|
type LocalStorage struct {
|
||||||
@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||||
|
err = os.WriteFile(filePath, imageData, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error writing to file:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s LocalStorage) Delete(fileURL string) error {
|
func (s LocalStorage) Delete(fileURL string) error {
|
||||||
if _, err := os.Stat(fileURL); err == nil {
|
if _, err := os.Stat(fileURL); err == nil {
|
||||||
return os.Remove(fileURL)
|
return os.Remove(fileURL)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
info, err := s.client.PutObject(
|
||||||
|
context.Background(),
|
||||||
|
s.config.Bucket,
|
||||||
|
objectKey,
|
||||||
|
strings.NewReader(string(imageData)),
|
||||||
|
int64(len(imageData)),
|
||||||
|
minio.PutObjectOptions{ContentType: "image/png"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s MiniOss) Delete(fileURL string) error {
|
func (s MiniOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
|||||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||||
|
}
|
||||||
|
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||||
|
ret := storage.PutRet{}
|
||||||
|
extra := storage.PutExtra{}
|
||||||
|
// 上传文件字节数据
|
||||||
|
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s QinNiuOss) Delete(fileURL string) error {
|
func (s QinNiuOss) Delete(fileURL string) error {
|
||||||
var objectKey string
|
var objectKey string
|
||||||
if strings.HasPrefix(fileURL, "http") {
|
if strings.HasPrefix(fileURL, "http") {
|
||||||
|
@ -17,5 +17,6 @@ type File struct {
|
|||||||
type Uploader interface {
|
type Uploader interface {
|
||||||
PutFile(ctx *gin.Context, name string) (File, error)
|
PutFile(ctx *gin.Context, name string) (File, error)
|
||||||
PutImg(imageURL string, useProxy bool) (string, error)
|
PutImg(imageURL string, useProxy bool) (string, error)
|
||||||
|
PutBase64(imageData string) (string, error)
|
||||||
Delete(fileURL string) error
|
Delete(fileURL string) error
|
||||||
}
|
}
|
||||||
|
@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||||
// create mj client and service
|
// create mj client and service
|
||||||
for k, config := range appConfig.SdConfigs {
|
for _, config := range appConfig.SdConfigs {
|
||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// create sd service
|
// create sd service
|
||||||
name := fmt.Sprintf("StableDifffusion Service-%d", k)
|
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
|
||||||
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
|
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
|
||||||
// run sd service
|
// run sd service
|
||||||
go func() {
|
go func() {
|
||||||
service.Run()
|
service.Run()
|
||||||
@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) {
|
|||||||
|
|
||||||
func (p *ServicePool) CheckTaskNotify() {
|
func (p *ServicePool) CheckTaskNotify() {
|
||||||
go func() {
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||||
for {
|
for {
|
||||||
var userId uint
|
var userId uint
|
||||||
err := p.notifyQueue.LPop(&userId)
|
err := p.notifyQueue.LPop(&userId)
|
||||||
@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
|
|||||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
func (p *ServicePool) CheckTaskStatus() {
|
func (p *ServicePool) CheckTaskStatus() {
|
||||||
go func() {
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
for {
|
for {
|
||||||
var jobs []model.SdJob
|
var jobs []model.SdJob
|
||||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
@ -6,16 +6,11 @@ import (
|
|||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
@ -28,13 +23,10 @@ type Service struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
name string // service name
|
name string // service name
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
|
||||||
handledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||||
|
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
name: name,
|
||||||
config: config,
|
config: config,
|
||||||
@ -43,22 +35,11 @@ func NewService(name string, maxTaskNum int32, timeout int64, config types.Stabl
|
|||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
db: db,
|
db: db,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
taskTimeout: timeout,
|
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
for {
|
for {
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.SdTask
|
var task types.SdTask
|
||||||
err := s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -74,239 +55,135 @@ func (s *Service) Run() {
|
|||||||
"progress": -1,
|
"progress": -1,
|
||||||
"err_msg": err.Error(),
|
"err_msg": err.Error(),
|
||||||
})
|
})
|
||||||
// release task num
|
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
||||||
// 通知前端,任务失败
|
// 通知前端,任务失败
|
||||||
s.notifyQueue.RPush(task.UserId)
|
s.notifyQueue.RPush(task.UserId)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[task.Id] = time.Now()
|
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
// Txt2ImgReq 文生图请求实体
|
||||||
func (s *Service) canHandleTask() bool {
|
type Txt2ImgReq struct {
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
Prompt string `json:"prompt"`
|
||||||
return handledNum < s.maxHandleTaskNum
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
|
Seed int64 `json:"seed"`
|
||||||
|
Steps int `json:"steps"`
|
||||||
|
CfgScale float32 `json:"cfg_scale"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
SamplerName string `json:"sampler_name"`
|
||||||
|
EnableHr bool `json:"enable_hr,omitempty"`
|
||||||
|
HrScale int `json:"hr_scale,omitempty"`
|
||||||
|
HrUpscaler string `json:"hr_upscaler,omitempty"`
|
||||||
|
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
|
||||||
|
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
|
||||||
|
ForceTaskId string `json:"force_task_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove the expired tasks
|
// Txt2ImgResp 文生图响应实体
|
||||||
func (s *Service) checkTasks() {
|
type Txt2ImgResp struct {
|
||||||
for k, t := range s.taskStartTimes {
|
Images []string `json:"images"`
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
Parameters struct {
|
||||||
delete(s.taskStartTimes, k)
|
} `json:"parameters"`
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
Info string `json:"info"`
|
||||||
// delete task from database
|
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskProgressResp 任务进度响应实体
|
||||||
|
type TaskProgressResp struct {
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
EtaRelative float64 `json:"eta_relative"`
|
||||||
|
CurrentImage string `json:"current_image"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2Img 文生图 API
|
// Txt2Img 文生图 API
|
||||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||||
var taskInfo TaskInfo
|
body := Txt2ImgReq{
|
||||||
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
|
Prompt: task.Params.Prompt,
|
||||||
if err != nil {
|
NegativePrompt: task.Params.NegativePrompt,
|
||||||
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
|
Steps: task.Params.Steps,
|
||||||
|
CfgScale: task.Params.CfgScale,
|
||||||
|
Width: task.Params.Width,
|
||||||
|
Height: task.Params.Height,
|
||||||
|
SamplerName: task.Params.Sampler,
|
||||||
}
|
}
|
||||||
|
if task.Params.Seed > 0 {
|
||||||
err = json.Unmarshal(bytes, &taskInfo)
|
body.Seed = task.Params.Seed
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode json params: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
|
if task.Params.HdFix {
|
||||||
data := taskInfo.Data
|
body.EnableHr = true
|
||||||
params := task.Params
|
body.HrScale = task.Params.HdScale
|
||||||
data[ParamKeys["task_id"]] = params.TaskId
|
body.HrUpscaler = task.Params.HdScaleAlg
|
||||||
data[ParamKeys["prompt"]] = params.Prompt
|
body.HrSecondPassSteps = task.Params.HdSteps
|
||||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
body.DenoisingStrength = task.Params.HdRedrawRate
|
||||||
data[ParamKeys["steps"]] = params.Steps
|
}
|
||||||
data[ParamKeys["sampler"]] = params.Sampler
|
var res Txt2ImgResp
|
||||||
// @fix bug: 有些 stable diffusion 没有面部修复功能
|
var errChan = make(chan error)
|
||||||
//data[ParamKeys["face_fix"]] = params.FaceFix
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
data[ParamKeys["seed"]] = params.Seed
|
|
||||||
data[ParamKeys["height"]] = params.Height
|
|
||||||
data[ParamKeys["width"]] = params.Width
|
|
||||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
|
||||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
|
||||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
|
||||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
|
||||||
data[ParamKeys["hd_sample_num"]] = params.HdSteps
|
|
||||||
|
|
||||||
taskInfo.SessionId = task.SessionId
|
|
||||||
taskInfo.TaskId = params.TaskId
|
|
||||||
taskInfo.Data = data
|
|
||||||
taskInfo.JobId = task.Id
|
|
||||||
taskInfo.UserId = uint(task.UserId)
|
|
||||||
go func() {
|
go func() {
|
||||||
s.runTask(taskInfo, s.httpClient)
|
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行任务
|
|
||||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
|
||||||
body := map[string]any{
|
|
||||||
"data": taskInfo.Data,
|
|
||||||
"event_data": taskInfo.EventData,
|
|
||||||
"fn_index": taskInfo.FnIndex,
|
|
||||||
"session_hash": taskInfo.SessionHash,
|
|
||||||
}
|
|
||||||
var result = make(chan CBReq)
|
|
||||||
go func() {
|
|
||||||
var res struct {
|
|
||||||
Data []interface{} `json:"data"`
|
|
||||||
IsGenerating bool `json:"is_generating"`
|
|
||||||
Duration float64 `json:"duration"`
|
|
||||||
AverageDuration float64 `json:"average_duration"`
|
|
||||||
}
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with send request: " + err.Error()
|
errChan <- err
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.IsErrorState() {
|
if response.IsErrorState() {
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
errChan <- fmt.Errorf("error http code status: %v", response.Status)
|
||||||
cbReq.Message = "error http status code: " + string(bytes)
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var images []struct {
|
// 保存 Base64 图片
|
||||||
Name string `json:"name"`
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||||
Data interface{} `json:"data"`
|
|
||||||
IsFile bool `json:"is_file"`
|
|
||||||
}
|
|
||||||
err = utils.ForceCovert(res.Data[0], &images)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cbReq.Message = "error with decode image:" + err.Error()
|
errChan <- fmt.Errorf("error with upload image: %v", err)
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 获取绘画真实的 seed
|
||||||
var info map[string]any
|
var info map[string]interface{}
|
||||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
err = utils.JsonDecode(res.Info, &info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(res.Data)
|
errChan <- fmt.Errorf("error with decode task response: %v", err)
|
||||||
cbReq.Message = "error with decode image url:" + err.Error()
|
|
||||||
cbReq.Success = false
|
|
||||||
result <- cbReq
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
||||||
// 获取真实的 seed 值
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
|
||||||
cbReq.ImageName = images[0].Name
|
errChan <- nil
|
||||||
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
|
|
||||||
cbReq.Seed = seed
|
|
||||||
cbReq.Success = true
|
|
||||||
cbReq.Progress = 100
|
|
||||||
result <- cbReq
|
|
||||||
close(result)
|
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case value := <-result:
|
case err := <-errChan: // 任务完成
|
||||||
s.callback(value)
|
if err != nil {
|
||||||
return
|
return err
|
||||||
|
}
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
|
s.notifyQueue.RPush(task.UserId)
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
var progressReq = map[string]any{
|
err, resp := s.checkTaskProgress()
|
||||||
"id_task": taskInfo.TaskId,
|
// 更新任务进度
|
||||||
"id_live_preview": 1,
|
if err == nil && resp.Progress > 0 {
|
||||||
|
logger.Debugf("Check task progress: %+v", resp.Progress)
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
|
// 发送更新状态信号
|
||||||
|
s.notifyQueue.RPush(task.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
var progressRes struct {
|
|
||||||
Active bool `json:"active"`
|
|
||||||
Queued bool `json:"queued"`
|
|
||||||
Completed bool `json:"completed"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
Eta float64 `json:"eta"`
|
|
||||||
LivePreview string `json:"live_preview"`
|
|
||||||
IDLivePreview int `json:"id_live_preview"`
|
|
||||||
TextInfo interface{} `json:"textinfo"`
|
|
||||||
}
|
|
||||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
|
|
||||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
|
||||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
|
||||||
logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if response.IsErrorState() {
|
|
||||||
bytes, _ := io.ReadAll(response.Body)
|
|
||||||
logger.Error(string(bytes))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cbReq.ImageData = progressRes.LivePreview
|
|
||||||
cbReq.Progress = int(progressRes.Progress * 100)
|
|
||||||
s.callback(cbReq)
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) callback(data CBReq) {
|
// 执行任务
|
||||||
// release task num
|
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||||
if data.Success { // 任务成功
|
var res TaskProgressResp
|
||||||
var job model.SdJob
|
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Warn("非法任务:", res.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 更新任务进度
|
|
||||||
job.Progress = data.Progress
|
|
||||||
// 更新任务 seed
|
|
||||||
var params types.SdTaskParams
|
|
||||||
err := utils.JsonDecode(job.Params, ¶ms)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("任务解析失败:", err)
|
return err, nil
|
||||||
return
|
}
|
||||||
|
if response.IsErrorState() {
|
||||||
|
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
params.Seed = data.Seed
|
return nil, &res
|
||||||
if data.ImageName != "" { // 下载图片
|
|
||||||
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
|
||||||
if data.Progress == 100 {
|
|
||||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download img: ", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
job.ImgURL = imageURL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
job.Params = utils.JsonEncode(params)
|
|
||||||
res = s.db.Updates(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update job: ", res.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("绘图进度:%d", data.Progress)
|
|
||||||
} else { // 任务失败
|
|
||||||
logger.Error("任务执行失败:", data.Message)
|
|
||||||
// update the task progress
|
|
||||||
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": data.Message,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送更新状态信号
|
|
||||||
s.notifyQueue.RPush(data.UserId)
|
|
||||||
}
|
}
|
||||||
|
@ -1,21 +1,5 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Person struct {
|
|
||||||
Name string
|
|
||||||
Age int
|
|
||||||
}
|
|
||||||
|
|
||||||
type Student struct {
|
|
||||||
Person
|
|
||||||
School string
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
fmt.Println(utils.RandString(64))
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user