mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: stable-diffusion refactored, replace websocket api with sdapi
This commit is contained in:
		@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
@@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) Delete(fileURL string) error {
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
 
 | 
			
		||||
@@ -3,13 +3,13 @@ package oss
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type LocalStorage struct {
 | 
			
		||||
@@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
 | 
			
		||||
	err = os.WriteFile(filePath, imageData, 0644)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error writing to file:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) Delete(fileURL string) error {
 | 
			
		||||
	if _, err := os.Stat(fileURL); err == nil {
 | 
			
		||||
		return os.Remove(fileURL)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
@@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	info, err := s.client.PutObject(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		s.config.Bucket,
 | 
			
		||||
		objectKey,
 | 
			
		||||
		strings.NewReader(string(imageData)),
 | 
			
		||||
		int64(len(imageData)),
 | 
			
		||||
		minio.PutObjectOptions{ContentType: "image/png"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) Delete(fileURL string) error {
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
@@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	ret := storage.PutRet{}
 | 
			
		||||
	extra := storage.PutExtra{}
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) Delete(fileURL string) error {
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
 
 | 
			
		||||
@@ -17,5 +17,6 @@ type File struct {
 | 
			
		||||
type Uploader interface {
 | 
			
		||||
	PutFile(ctx *gin.Context, name string) (File, error)
 | 
			
		||||
	PutImg(imageURL string, useProxy bool) (string, error)
 | 
			
		||||
	PutBase64(imageData string) (string, error)
 | 
			
		||||
	Delete(fileURL string) error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
			
		||||
	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
 | 
			
		||||
	// create mj client and service
 | 
			
		||||
	for k, config := range appConfig.SdConfigs {
 | 
			
		||||
	for _, config := range appConfig.SdConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// create sd service
 | 
			
		||||
		name := fmt.Sprintf("StableDifffusion Service-%d", k)
 | 
			
		||||
		service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager)
 | 
			
		||||
		name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
 | 
			
		||||
		service := NewService(name, config, taskQueue, notifyQueue, db, manager)
 | 
			
		||||
		// run sd service
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
@@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) {
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var userId uint
 | 
			
		||||
			err := p.notifyQueue.LPop(&userId)
 | 
			
		||||
@@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (p *ServicePool) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := p.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,59 +6,40 @@ import (
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SD 绘画服务
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient       *req.Client
 | 
			
		||||
	config           types.StableDiffusionConfig
 | 
			
		||||
	taskQueue        *store.RedisQueue
 | 
			
		||||
	notifyQueue      *store.RedisQueue
 | 
			
		||||
	db               *gorm.DB
 | 
			
		||||
	uploadManager    *oss.UploaderManager
 | 
			
		||||
	name             string            // service name
 | 
			
		||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
			
		||||
	handledTaskNum   int32             // already handled task number
 | 
			
		||||
	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
			
		||||
	taskTimeout      int64
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	config        types.StableDiffusionConfig
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	name          string // service name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
 | 
			
		||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
 | 
			
		||||
	config.ApiURL = strings.TrimRight(config.ApiURL, "/")
 | 
			
		||||
	return &Service{
 | 
			
		||||
		name:             name,
 | 
			
		||||
		config:           config,
 | 
			
		||||
		httpClient:       req.C(),
 | 
			
		||||
		taskQueue:        taskQueue,
 | 
			
		||||
		notifyQueue:      notifyQueue,
 | 
			
		||||
		db:               db,
 | 
			
		||||
		uploadManager:    manager,
 | 
			
		||||
		taskTimeout:      timeout,
 | 
			
		||||
		maxHandleTaskNum: maxTaskNum,
 | 
			
		||||
		taskStartTimes:   make(map[int]time.Time),
 | 
			
		||||
		name:          name,
 | 
			
		||||
		config:        config,
 | 
			
		||||
		httpClient:    req.C(),
 | 
			
		||||
		taskQueue:     taskQueue,
 | 
			
		||||
		notifyQueue:   notifyQueue,
 | 
			
		||||
		db:            db,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	for {
 | 
			
		||||
		s.checkTasks()
 | 
			
		||||
		if !s.canHandleTask() {
 | 
			
		||||
			// current service is full, can not handle more task
 | 
			
		||||
			// waiting for running task finish
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var task types.SdTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -74,239 +55,135 @@ func (s *Service) Run() {
 | 
			
		||||
				"progress": -1,
 | 
			
		||||
				"err_msg":  err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			// release task num
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
			// 通知前端,任务失败
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// lock the task until the execute timeout
 | 
			
		||||
		s.taskStartTimes[task.Id] = time.Now()
 | 
			
		||||
		atomic.AddInt32(&s.handledTaskNum, 1)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check if current service instance can handle more task
 | 
			
		||||
func (s *Service) canHandleTask() bool {
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.handledTaskNum)
 | 
			
		||||
	return handledNum < s.maxHandleTaskNum
 | 
			
		||||
// Txt2ImgReq 文生图请求实体
 | 
			
		||||
type Txt2ImgReq struct {
 | 
			
		||||
	Prompt            string  `json:"prompt"`
 | 
			
		||||
	NegativePrompt    string  `json:"negative_prompt"`
 | 
			
		||||
	Seed              int64   `json:"seed"`
 | 
			
		||||
	Steps             int     `json:"steps"`
 | 
			
		||||
	CfgScale          float32 `json:"cfg_scale"`
 | 
			
		||||
	Width             int     `json:"width"`
 | 
			
		||||
	Height            int     `json:"height"`
 | 
			
		||||
	SamplerName       string  `json:"sampler_name"`
 | 
			
		||||
	EnableHr          bool    `json:"enable_hr,omitempty"`
 | 
			
		||||
	HrScale           int     `json:"hr_scale,omitempty"`
 | 
			
		||||
	HrUpscaler        string  `json:"hr_upscaler,omitempty"`
 | 
			
		||||
	HrSecondPassSteps int     `json:"hr_second_pass_steps,omitempty"`
 | 
			
		||||
	DenoisingStrength float32 `json:"denoising_strength,omitempty"`
 | 
			
		||||
	ForceTaskId       string  `json:"force_task_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// remove the expired tasks
 | 
			
		||||
func (s *Service) checkTasks() {
 | 
			
		||||
	for k, t := range s.taskStartTimes {
 | 
			
		||||
		if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
			
		||||
			delete(s.taskStartTimes, k)
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
			// delete task from database
 | 
			
		||||
			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
// Txt2ImgResp 文生图响应实体
 | 
			
		||||
type Txt2ImgResp struct {
 | 
			
		||||
	Images     []string `json:"images"`
 | 
			
		||||
	Parameters struct {
 | 
			
		||||
	} `json:"parameters"`
 | 
			
		||||
	Info string `json:"info"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TaskProgressResp 任务进度响应实体
 | 
			
		||||
type TaskProgressResp struct {
 | 
			
		||||
	Progress     float64 `json:"progress"`
 | 
			
		||||
	EtaRelative  float64 `json:"eta_relative"`
 | 
			
		||||
	CurrentImage string  `json:"current_image"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Txt2Img 文生图 API
 | 
			
		||||
func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
	var taskInfo TaskInfo
 | 
			
		||||
	bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with load text2img json template file: %s", err.Error())
 | 
			
		||||
	body := Txt2ImgReq{
 | 
			
		||||
		Prompt:         task.Params.Prompt,
 | 
			
		||||
		NegativePrompt: task.Params.NegativePrompt,
 | 
			
		||||
		Steps:          task.Params.Steps,
 | 
			
		||||
		CfgScale:       task.Params.CfgScale,
 | 
			
		||||
		Width:          task.Params.Width,
 | 
			
		||||
		Height:         task.Params.Height,
 | 
			
		||||
		SamplerName:    task.Params.Sampler,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = json.Unmarshal(bytes, &taskInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with decode json params: %s", err.Error())
 | 
			
		||||
	if task.Params.Seed > 0 {
 | 
			
		||||
		body.Seed = task.Params.Seed
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	data := taskInfo.Data
 | 
			
		||||
	params := task.Params
 | 
			
		||||
	data[ParamKeys["task_id"]] = params.TaskId
 | 
			
		||||
	data[ParamKeys["prompt"]] = params.Prompt
 | 
			
		||||
	data[ParamKeys["negative_prompt"]] = params.NegativePrompt
 | 
			
		||||
	data[ParamKeys["steps"]] = params.Steps
 | 
			
		||||
	data[ParamKeys["sampler"]] = params.Sampler
 | 
			
		||||
	// @fix bug: 有些 stable diffusion 没有面部修复功能
 | 
			
		||||
	//data[ParamKeys["face_fix"]] = params.FaceFix
 | 
			
		||||
	data[ParamKeys["cfg_scale"]] = params.CfgScale
 | 
			
		||||
	data[ParamKeys["seed"]] = params.Seed
 | 
			
		||||
	data[ParamKeys["height"]] = params.Height
 | 
			
		||||
	data[ParamKeys["width"]] = params.Width
 | 
			
		||||
	data[ParamKeys["hd_fix"]] = params.HdFix
 | 
			
		||||
	data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
 | 
			
		||||
	data[ParamKeys["hd_scale"]] = params.HdScale
 | 
			
		||||
	data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
 | 
			
		||||
	data[ParamKeys["hd_sample_num"]] = params.HdSteps
 | 
			
		||||
 | 
			
		||||
	taskInfo.SessionId = task.SessionId
 | 
			
		||||
	taskInfo.TaskId = params.TaskId
 | 
			
		||||
	taskInfo.Data = data
 | 
			
		||||
	taskInfo.JobId = task.Id
 | 
			
		||||
	taskInfo.UserId = uint(task.UserId)
 | 
			
		||||
	if task.Params.HdFix {
 | 
			
		||||
		body.EnableHr = true
 | 
			
		||||
		body.HrScale = task.Params.HdScale
 | 
			
		||||
		body.HrUpscaler = task.Params.HdScaleAlg
 | 
			
		||||
		body.HrSecondPassSteps = task.Params.HdSteps
 | 
			
		||||
		body.DenoisingStrength = task.Params.HdRedrawRate
 | 
			
		||||
	}
 | 
			
		||||
	var res Txt2ImgResp
 | 
			
		||||
	var errChan = make(chan error)
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
 | 
			
		||||
	logger.Debugf("send image request to %s", apiURL)
 | 
			
		||||
	go func() {
 | 
			
		||||
		s.runTask(taskInfo, s.httpClient)
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行任务
 | 
			
		||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
 | 
			
		||||
	body := map[string]any{
 | 
			
		||||
		"data":         taskInfo.Data,
 | 
			
		||||
		"event_data":   taskInfo.EventData,
 | 
			
		||||
		"fn_index":     taskInfo.FnIndex,
 | 
			
		||||
		"session_hash": taskInfo.SessionHash,
 | 
			
		||||
	}
 | 
			
		||||
	var result = make(chan CBReq)
 | 
			
		||||
	go func() {
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Data            []interface{} `json:"data"`
 | 
			
		||||
			IsGenerating    bool          `json:"is_generating"`
 | 
			
		||||
			Duration        float64       `json:"duration"`
 | 
			
		||||
			AverageDuration float64       `json:"average_duration"`
 | 
			
		||||
		}
 | 
			
		||||
		var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
 | 
			
		||||
		response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
 | 
			
		||||
		response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cbReq.Message = "error with send request: " + err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			errChan <- err
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if response.IsErrorState() {
 | 
			
		||||
			bytes, _ := io.ReadAll(response.Body)
 | 
			
		||||
			cbReq.Message = "error http status code: " + string(bytes)
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			errChan <- fmt.Errorf("error http code status: %v", response.Status)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var images []struct {
 | 
			
		||||
			Name   string      `json:"name"`
 | 
			
		||||
			Data   interface{} `json:"data"`
 | 
			
		||||
			IsFile bool        `json:"is_file"`
 | 
			
		||||
		}
 | 
			
		||||
		err = utils.ForceCovert(res.Data[0], &images)
 | 
			
		||||
		// 保存 Base64 图片
 | 
			
		||||
		imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cbReq.Message = "error with decode image:" + err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			errChan <- fmt.Errorf("error with upload image: %v", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var info map[string]any
 | 
			
		||||
		err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
 | 
			
		||||
		// 获取绘画真实的 seed
 | 
			
		||||
		var info map[string]interface{}
 | 
			
		||||
		err = utils.JsonDecode(res.Info, &info)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error(res.Data)
 | 
			
		||||
			cbReq.Message = "error with decode image url:" + err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			errChan <- fmt.Errorf("error with decode task response: %v", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 获取真实的 seed 值
 | 
			
		||||
		cbReq.ImageName = images[0].Name
 | 
			
		||||
		seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
 | 
			
		||||
		cbReq.Seed = seed
 | 
			
		||||
		cbReq.Success = true
 | 
			
		||||
		cbReq.Progress = 100
 | 
			
		||||
		result <- cbReq
 | 
			
		||||
		close(result)
 | 
			
		||||
 | 
			
		||||
		task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
 | 
			
		||||
		s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
 | 
			
		||||
		errChan <- nil
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case value := <-result:
 | 
			
		||||
			s.callback(value)
 | 
			
		||||
			return
 | 
			
		||||
		case err := <-errChan: // 任务完成
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			var progressReq = map[string]any{
 | 
			
		||||
				"id_task":         taskInfo.TaskId,
 | 
			
		||||
				"id_live_preview": 1,
 | 
			
		||||
			err, resp := s.checkTaskProgress()
 | 
			
		||||
			// 更新任务进度
 | 
			
		||||
			if err == nil && resp.Progress > 0 {
 | 
			
		||||
				logger.Debugf("Check task progress: %+v", resp.Progress)
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
			
		||||
				// 发送更新状态信号
 | 
			
		||||
				s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var progressRes struct {
 | 
			
		||||
				Active        bool        `json:"active"`
 | 
			
		||||
				Queued        bool        `json:"queued"`
 | 
			
		||||
				Completed     bool        `json:"completed"`
 | 
			
		||||
				Progress      float64     `json:"progress"`
 | 
			
		||||
				Eta           float64     `json:"eta"`
 | 
			
		||||
				LivePreview   string      `json:"live_preview"`
 | 
			
		||||
				IDLivePreview int         `json:"id_live_preview"`
 | 
			
		||||
				TextInfo      interface{} `json:"textinfo"`
 | 
			
		||||
			}
 | 
			
		||||
			response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
 | 
			
		||||
			var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
 | 
			
		||||
			if err != nil { // TODO: 这里可以考虑设置失败重试次数
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if response.IsErrorState() {
 | 
			
		||||
				bytes, _ := io.ReadAll(response.Body)
 | 
			
		||||
				logger.Error(string(bytes))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			cbReq.ImageData = progressRes.LivePreview
 | 
			
		||||
			cbReq.Progress = int(progressRes.Progress * 100)
 | 
			
		||||
			s.callback(cbReq)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) callback(data CBReq) {
 | 
			
		||||
	// release task num
 | 
			
		||||
	atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
	if data.Success { // 任务成功
 | 
			
		||||
		var job model.SdJob
 | 
			
		||||
		res := s.db.Where("id = ?", data.JobId).First(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Warn("非法任务:", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 更新任务进度
 | 
			
		||||
		job.Progress = data.Progress
 | 
			
		||||
		// 更新任务 seed
 | 
			
		||||
		var params types.SdTaskParams
 | 
			
		||||
		err := utils.JsonDecode(job.Params, ¶ms)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("任务解析失败:", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		params.Seed = data.Seed
 | 
			
		||||
		if data.ImageName != "" { // 下载图片
 | 
			
		||||
			job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
 | 
			
		||||
			if data.Progress == 100 {
 | 
			
		||||
				imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("error with download img: ", err.Error())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				job.ImgURL = imageURL
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		job.Params = utils.JsonEncode(params)
 | 
			
		||||
		res = s.db.Updates(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update job: ", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Debugf("绘图进度:%d", data.Progress)
 | 
			
		||||
	} else { // 任务失败
 | 
			
		||||
		logger.Error("任务执行失败:", data.Message)
 | 
			
		||||
		// update the task progress
 | 
			
		||||
		s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  data.Message,
 | 
			
		||||
		})
 | 
			
		||||
// 执行任务
 | 
			
		||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
 | 
			
		||||
	var res TaskProgressResp
 | 
			
		||||
	response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	if response.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error http code status: %v", response.Status), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 发送更新状态信号
 | 
			
		||||
	s.notifyQueue.RPush(data.UserId)
 | 
			
		||||
	return nil, &res
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user