mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: stable-diffusion refactored, replace websocket api with sdapi
This commit is contained in:
parent
49b5906bc7
commit
342b76f666
@ -56,10 +56,10 @@ type MidJourneyConfig struct {
|
||||
}
|
||||
|
||||
type StableDiffusionConfig struct {
|
||||
Enabled bool
|
||||
ApiURL string
|
||||
ApiKey string
|
||||
Txt2ImgJsonPath string
|
||||
Enabled bool
|
||||
Model string // 模型名称
|
||||
ApiURL string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type MidJourneyPlusConfig struct {
|
||||
|
@ -175,6 +175,12 @@ func main() {
|
||||
|
||||
// Stable Diffusion 机器人
|
||||
fx.Provide(sd.NewServicePool),
|
||||
fx.Invoke(func(pool *sd.ServicePool) {
|
||||
if pool.HasAvailableService() {
|
||||
pool.CheckTaskNotify()
|
||||
pool.CheckTaskStatus()
|
||||
}
|
||||
}),
|
||||
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
|
@ -1,80 +0,0 @@
|
||||
{
|
||||
"data": [
|
||||
"task(cxvkpawy8onnfti)",
|
||||
"a cute girl",
|
||||
"",
|
||||
[],
|
||||
20,
|
||||
"DPM++ 2M Karras",
|
||||
1,
|
||||
1,
|
||||
7,
|
||||
512,
|
||||
512,
|
||||
false,
|
||||
0.7,
|
||||
2,
|
||||
"Latent",
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
"Use same checkpoint",
|
||||
"Use same sampler",
|
||||
"",
|
||||
"",
|
||||
[],
|
||||
"None",
|
||||
false,
|
||||
"",
|
||||
0.8,
|
||||
-1,
|
||||
false,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
false,
|
||||
"positive",
|
||||
"comma",
|
||||
0,
|
||||
false,
|
||||
false,
|
||||
"",
|
||||
"Seed",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
50,
|
||||
[],
|
||||
"",
|
||||
"",
|
||||
""
|
||||
],
|
||||
"event_data": null,
|
||||
"fn_index": 446,
|
||||
"session_hash": "nk5noh1rz1o"
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
|
@ -3,13 +3,13 @@ package oss
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||
err = os.WriteFile(filePath, imageData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing to file:%v", err)
|
||||
}
|
||||
|
||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) Delete(fileURL string) error {
|
||||
if _, err := os.Stat(fileURL); err == nil {
|
||||
return os.Remove(fileURL)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
objectKey,
|
||||
strings.NewReader(string(imageData)),
|
||||
int64(len(imageData)),
|
||||
minio.PutObjectOptions{ContentType: "image/png"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
|
||||
}
|
||||
|
||||
func (s MiniOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
|
@ -17,5 +17,6 @@ type File struct {
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutImg(imageURL string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||
// create mj client and service
|
||||
for k, config := range appConfig.SdConfigs {
|
||||
for _, config := range appConfig.SdConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf("StableDifffusion Service-%d", k)
|
||||
service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
|
||||
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
|
||||
service := NewService(name, config, taskQueue, notifyQueue, db, manager)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) {
|
||||
|
||||
func (p *ServicePool) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var userId uint
|
||||
err := p.notifyQueue.LPop(&userId)
|
||||
@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
func (p *ServicePool) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||
for {
|
||||
var jobs []model.SdJob
|
||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||
|
@ -6,59 +6,40 @@ import (
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
config types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
handledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
httpClient *req.Client
|
||||
config types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
}
|
||||
|
||||
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||
return &Service{
|
||||
name: name,
|
||||
config: config,
|
||||
httpClient: req.C(),
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
db: db,
|
||||
uploadManager: manager,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
taskStartTimes: make(map[int]time.Time),
|
||||
name: name,
|
||||
config: config,
|
||||
httpClient: req.C(),
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
db: db,
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
for {
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
var task types.SdTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
@ -74,239 +55,135 @@ func (s *Service) Run() {
|
||||
"progress": -1,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// release task num
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
continue
|
||||
}
|
||||
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
// Txt2ImgReq 文生图请求实体
|
||||
type Txt2ImgReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt"`
|
||||
Seed int64 `json:"seed"`
|
||||
Steps int `json:"steps"`
|
||||
CfgScale float32 `json:"cfg_scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
SamplerName string `json:"sampler_name"`
|
||||
EnableHr bool `json:"enable_hr,omitempty"`
|
||||
HrScale int `json:"hr_scale,omitempty"`
|
||||
HrUpscaler string `json:"hr_upscaler,omitempty"`
|
||||
HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"`
|
||||
DenoisingStrength float32 `json:"denoising_strength,omitempty"`
|
||||
ForceTaskId string `json:"force_task_id,omitempty"`
|
||||
}
|
||||
|
||||
// remove the expired tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
// delete task from database
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||
}
|
||||
}
|
||||
// Txt2ImgResp 文生图响应实体
|
||||
type Txt2ImgResp struct {
|
||||
Images []string `json:"images"`
|
||||
Parameters struct {
|
||||
} `json:"parameters"`
|
||||
Info string `json:"info"`
|
||||
}
|
||||
|
||||
// TaskProgressResp 任务进度响应实体
|
||||
type TaskProgressResp struct {
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
CurrentImage string `json:"current_image"`
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
var taskInfo TaskInfo
|
||||
bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with load text2img json template file: %s", err.Error())
|
||||
body := Txt2ImgReq{
|
||||
Prompt: task.Params.Prompt,
|
||||
NegativePrompt: task.Params.NegativePrompt,
|
||||
Steps: task.Params.Steps,
|
||||
CfgScale: task.Params.CfgScale,
|
||||
Width: task.Params.Width,
|
||||
Height: task.Params.Height,
|
||||
SamplerName: task.Params.Sampler,
|
||||
}
|
||||
|
||||
err = json.Unmarshal(bytes, &taskInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode json params: %s", err.Error())
|
||||
if task.Params.Seed > 0 {
|
||||
body.Seed = task.Params.Seed
|
||||
}
|
||||
|
||||
data := taskInfo.Data
|
||||
params := task.Params
|
||||
data[ParamKeys["task_id"]] = params.TaskId
|
||||
data[ParamKeys["prompt"]] = params.Prompt
|
||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
||||
data[ParamKeys["steps"]] = params.Steps
|
||||
data[ParamKeys["sampler"]] = params.Sampler
|
||||
// @fix bug: 有些 stable diffusion 没有面部修复功能
|
||||
//data[ParamKeys["face_fix"]] = params.FaceFix
|
||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
||||
data[ParamKeys["seed"]] = params.Seed
|
||||
data[ParamKeys["height"]] = params.Height
|
||||
data[ParamKeys["width"]] = params.Width
|
||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
||||
data[ParamKeys["hd_sample_num"]] = params.HdSteps
|
||||
|
||||
taskInfo.SessionId = task.SessionId
|
||||
taskInfo.TaskId = params.TaskId
|
||||
taskInfo.Data = data
|
||||
taskInfo.JobId = task.Id
|
||||
taskInfo.UserId = uint(task.UserId)
|
||||
if task.Params.HdFix {
|
||||
body.EnableHr = true
|
||||
body.HrScale = task.Params.HdScale
|
||||
body.HrUpscaler = task.Params.HdScaleAlg
|
||||
body.HrSecondPassSteps = task.Params.HdSteps
|
||||
body.DenoisingStrength = task.Params.HdRedrawRate
|
||||
}
|
||||
var res Txt2ImgResp
|
||||
var errChan = make(chan error)
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
go func() {
|
||||
s.runTask(taskInfo, s.httpClient)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||
body := map[string]any{
|
||||
"data": taskInfo.Data,
|
||||
"event_data": taskInfo.EventData,
|
||||
"fn_index": taskInfo.FnIndex,
|
||||
"session_hash": taskInfo.SessionHash,
|
||||
}
|
||||
var result = make(chan CBReq)
|
||||
go func() {
|
||||
var res struct {
|
||||
Data []interface{} `json:"data"`
|
||||
IsGenerating bool `json:"is_generating"`
|
||||
Duration float64 `json:"duration"`
|
||||
AverageDuration float64 `json:"average_duration"`
|
||||
}
|
||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
|
||||
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
cbReq.Message = "error with send request: " + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
cbReq.Message = "error http status code: " + string(bytes)
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
errChan <- fmt.Errorf("error http code status: %v", response.Status)
|
||||
return
|
||||
}
|
||||
|
||||
var images []struct {
|
||||
Name string `json:"name"`
|
||||
Data interface{} `json:"data"`
|
||||
IsFile bool `json:"is_file"`
|
||||
}
|
||||
err = utils.ForceCovert(res.Data[0], &images)
|
||||
// 保存 Base64 图片
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||
if err != nil {
|
||||
cbReq.Message = "error with decode image:" + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
errChan <- fmt.Errorf("error with upload image: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var info map[string]any
|
||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
||||
// 获取绘画真实的 seed
|
||||
var info map[string]interface{}
|
||||
err = utils.JsonDecode(res.Info, &info)
|
||||
if err != nil {
|
||||
logger.Error(res.Data)
|
||||
cbReq.Message = "error with decode image url:" + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
errChan <- fmt.Errorf("error with decode task response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取真实的 seed 值
|
||||
cbReq.ImageName = images[0].Name
|
||||
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
|
||||
cbReq.Seed = seed
|
||||
cbReq.Success = true
|
||||
cbReq.Progress = 100
|
||||
result <- cbReq
|
||||
close(result)
|
||||
|
||||
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
|
||||
errChan <- nil
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case value := <-result:
|
||||
s.callback(value)
|
||||
return
|
||||
case err := <-errChan: // 任务完成
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
return nil
|
||||
default:
|
||||
var progressReq = map[string]any{
|
||||
"id_task": taskInfo.TaskId,
|
||||
"id_live_preview": 1,
|
||||
err, resp := s.checkTaskProgress()
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
logger.Debugf("Check task progress: %+v", resp.Progress)
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
}
|
||||
|
||||
var progressRes struct {
|
||||
Active bool `json:"active"`
|
||||
Queued bool `json:"queued"`
|
||||
Completed bool `json:"completed"`
|
||||
Progress float64 `json:"progress"`
|
||||
Eta float64 `json:"eta"`
|
||||
LivePreview string `json:"live_preview"`
|
||||
IDLivePreview int `json:"id_live_preview"`
|
||||
TextInfo interface{} `json:"textinfo"`
|
||||
}
|
||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
|
||||
var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
logger.Error(string(bytes))
|
||||
return
|
||||
}
|
||||
|
||||
cbReq.ImageData = progressRes.LivePreview
|
||||
cbReq.Progress = int(progressRes.Progress * 100)
|
||||
s.callback(cbReq)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *Service) callback(data CBReq) {
|
||||
// release task num
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
if data.Success { // 任务成功
|
||||
var job model.SdJob
|
||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
// 更新任务进度
|
||||
job.Progress = data.Progress
|
||||
// 更新任务 seed
|
||||
var params types.SdTaskParams
|
||||
err := utils.JsonDecode(job.Params, ¶ms)
|
||||
if err != nil {
|
||||
logger.Error("任务解析失败:", err)
|
||||
return
|
||||
}
|
||||
|
||||
params.Seed = data.Seed
|
||||
if data.ImageName != "" { // 下载图片
|
||||
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
||||
if data.Progress == 100 {
|
||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imageURL
|
||||
}
|
||||
}
|
||||
|
||||
job.Params = utils.JsonEncode(params)
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("绘图进度:%d", data.Progress)
|
||||
} else { // 任务失败
|
||||
logger.Error("任务执行失败:", data.Message)
|
||||
// update the task progress
|
||||
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": data.Message,
|
||||
})
|
||||
// 执行任务
|
||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return fmt.Errorf("error http code status: %v", response.Status), nil
|
||||
}
|
||||
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(data.UserId)
|
||||
return nil, &res
|
||||
}
|
||||
|
@ -1,21 +1,5 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
type Student struct {
|
||||
Person
|
||||
School string
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
fmt.Println(utils.RandString(64))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user