mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	refactor stable diffusion service, use api key instead of configs
This commit is contained in:
		@@ -17,15 +17,14 @@ type AppConfig struct {
 | 
			
		||||
	Session      Session
 | 
			
		||||
	AdminSession Session
 | 
			
		||||
	ProxyURL     string
 | 
			
		||||
	MysqlDns     string                  // mysql 连接地址
 | 
			
		||||
	StaticDir    string                  // 静态资源目录
 | 
			
		||||
	StaticUrl    string                  // 静态资源 URL
 | 
			
		||||
	Redis        RedisConfig             // redis 连接信息
 | 
			
		||||
	ApiConfig    ApiConfig               // ChatPlus API authorization configs
 | 
			
		||||
	SMS          SMSConfig               // send mobile message config
 | 
			
		||||
	OSS          OSSConfig               // OSS config
 | 
			
		||||
	WeChatBot    bool                    // 是否启用微信机器人
 | 
			
		||||
	SdConfigs    []StableDiffusionConfig // sd AI draw service pool
 | 
			
		||||
	MysqlDns     string      // mysql 连接地址
 | 
			
		||||
	StaticDir    string      // 静态资源目录
 | 
			
		||||
	StaticUrl    string      // 静态资源 URL
 | 
			
		||||
	Redis        RedisConfig // redis 连接信息
 | 
			
		||||
	ApiConfig    ApiConfig   // ChatPlus API authorization configs
 | 
			
		||||
	SMS          SMSConfig   // send mobile message config
 | 
			
		||||
	OSS          OSSConfig   // OSS config
 | 
			
		||||
	WeChatBot    bool        // 是否启用微信机器人
 | 
			
		||||
 | 
			
		||||
	XXLConfig       XXLConfig
 | 
			
		||||
	AlipayConfig    AlipayConfig    // 支付宝支付渠道配置
 | 
			
		||||
@@ -51,27 +50,6 @@ type ApiConfig struct {
 | 
			
		||||
	Token  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjProxyConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StableDiffusionConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	Model   string // 模型名称
 | 
			
		||||
	ApiURL  string
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjPlusConfig struct {
 | 
			
		||||
	Enabled bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AlipayConfig struct {
 | 
			
		||||
	Enabled         bool   // 是否启用该支付通道
 | 
			
		||||
	SandBox         bool   // 是否沙盒环境
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,6 @@ import (
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
@@ -27,14 +26,12 @@ type ConfigHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	levelDB        *store.LevelDB
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	sdServicePool  *sd.ServicePool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler {
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
 | 
			
		||||
	return &ConfigHandler{
 | 
			
		||||
		BaseHandler:    handler.BaseHandler{App: app, DB: db},
 | 
			
		||||
		levelDB:        levelDB,
 | 
			
		||||
		sdServicePool:  sdPool,
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,15 +32,15 @@ import (
 | 
			
		||||
type SdJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	pool      *sd.ServicePool
 | 
			
		||||
	service   *sd.Service
 | 
			
		||||
	uploader  *oss.UploaderManager
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	leveldb   *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
 | 
			
		||||
	return &SdJobHandler{
 | 
			
		||||
		pool:      pool,
 | 
			
		||||
		service:   service,
 | 
			
		||||
		uploader:  manager,
 | 
			
		||||
		snowflake: snowflake,
 | 
			
		||||
		leveldb:   levelDB,
 | 
			
		||||
@@ -68,7 +68,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.pool.Clients.Put(uint(userId), client)
 | 
			
		||||
	h.service.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -79,11 +79,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.pool.HasAvailableService() {
 | 
			
		||||
		resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < h.App.SysConfig.SdPower {
 | 
			
		||||
		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
 | 
			
		||||
		return false
 | 
			
		||||
@@ -164,14 +159,14 @@ func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.SdTask{
 | 
			
		||||
	h.service.PushTask(types.SdTask{
 | 
			
		||||
		Id:     int(job.Id),
 | 
			
		||||
		Type:   types.TaskImage,
 | 
			
		||||
		Params: params,
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	client := h.service.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
@@ -328,11 +323,6 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte(service.TaskStatusFinished))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										12
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								api/main.go
									
									
									
									
									
								
							@@ -199,13 +199,11 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// Stable Diffusion 机器人
 | 
			
		||||
		fx.Provide(sd.NewServicePool),
 | 
			
		||||
		fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
 | 
			
		||||
			pool.InitServices(config.SdConfigs)
 | 
			
		||||
			if pool.HasAvailableService() {
 | 
			
		||||
				pool.CheckTaskNotify()
 | 
			
		||||
				pool.CheckTaskStatus()
 | 
			
		||||
			}
 | 
			
		||||
		fx.Provide(sd.NewService),
 | 
			
		||||
		fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.CheckTaskStatus()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(suno.NewService),
 | 
			
		||||
 
 | 
			
		||||
@@ -50,12 +50,6 @@ type ImageRes struct {
 | 
			
		||||
	Channel string `json:"channel,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueryRes struct {
 | 
			
		||||
	Action  string `json:"action"`
 | 
			
		||||
	Buttons []struct {
 | 
			
		||||
@@ -193,7 +187,6 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
 | 
			
		||||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
 | 
			
		||||
	if channel != "" {
 | 
			
		||||
		session = session.Where("api_url", channel)
 | 
			
		||||
@@ -215,20 +208,14 @@ func (c *Client) doRequest(body interface{}, apiPath string, channel string) (Im
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errMsg := err.Error()
 | 
			
		||||
		if r != nil {
 | 
			
		||||
			errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
			logger.Error("请求 API 出错:", string(errStr))
 | 
			
		||||
			errMsg = errMsg + " " + string(errStr)
 | 
			
		||||
		}
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
		errMsg, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update the api key last used time
 | 
			
		||||
 
 | 
			
		||||
@@ -1,128 +0,0 @@
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services    []*Service
 | 
			
		||||
	taskQueue   *store.RedisQueue
 | 
			
		||||
	notifyQueue *store.RedisQueue
 | 
			
		||||
	db          *gorm.DB
 | 
			
		||||
	Clients     *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	levelDB     *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
 | 
			
		||||
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
		taskQueue:   taskQueue,
 | 
			
		||||
		notifyQueue: notifyQueue,
 | 
			
		||||
		services:    services,
 | 
			
		||||
		db:          db,
 | 
			
		||||
		Clients:     types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploader:    manager,
 | 
			
		||||
		levelDB:     levelDB,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
 | 
			
		||||
	// stop old service
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		s.Stop()
 | 
			
		||||
	}
 | 
			
		||||
	p.services = make([]*Service, 0)
 | 
			
		||||
 | 
			
		||||
	for k, config := range configs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// create sd service
 | 
			
		||||
		name := fmt.Sprintf(" sd-service-%d", k)
 | 
			
		||||
		service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
 | 
			
		||||
		// run sd service
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		p.services = append(p.services, service)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (p *ServicePool) PushTask(task types.SdTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	p.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := p.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := p.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (p *ServicePool) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := p.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务标记为失败
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
 | 
			
		||||
					job.Progress = 101
 | 
			
		||||
					job.ErrMsg = "任务超时"
 | 
			
		||||
					p.db.Updates(&job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasAvailableService check if it has available mj service in pool
 | 
			
		||||
func (p *ServicePool) HasAvailableService() bool {
 | 
			
		||||
	return len(p.services) > 0
 | 
			
		||||
}
 | 
			
		||||
@@ -16,7 +16,7 @@ import (
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
@@ -29,79 +29,72 @@ var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	config        types.StableDiffusionConfig
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	name          string // service name
 | 
			
		||||
	leveldb       *store.LevelDB
 | 
			
		||||
	running       bool // 运行状态
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
 | 
			
		||||
	config.ApiURL = strings.TrimRight(config.ApiURL, "/")
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		name:          name,
 | 
			
		||||
		config:        config,
 | 
			
		||||
		httpClient:    req.C(),
 | 
			
		||||
		taskQueue:     taskQueue,
 | 
			
		||||
		notifyQueue:   notifyQueue,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("StableDiffusion_Queue", redisCli),
 | 
			
		||||
		db:            db,
 | 
			
		||||
		leveldb:       levelDB,
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		running:       true,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
 | 
			
		||||
	for s.running {
 | 
			
		||||
		var task types.SdTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	logger.Infof("Starting Stable-Diffusion job consumer")
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.SdTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// translate prompt
 | 
			
		||||
		if utils.HasChinese(task.Params.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Params.Prompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			// translate prompt
 | 
			
		||||
			if utils.HasChinese(task.Params.Prompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.Params.Prompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// translate negative prompt
 | 
			
		||||
			if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.Params.NegPrompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Infof("handle a new Stable-Diffusion task: %+v", task)
 | 
			
		||||
			err = s.Txt2Img(task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("绘画任务执行失败:", err.Error())
 | 
			
		||||
				// update the task progress
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": service.FailTaskProgress,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				// 通知前端,任务失败
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// translate negative prompt
 | 
			
		||||
		if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Params.NegPrompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
 | 
			
		||||
		err = s.Txt2Img(task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err.Error())
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"progress": service.FailTaskProgress,
 | 
			
		||||
				"err_msg":  err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			// 通知前端,任务失败
 | 
			
		||||
			s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Stop() {
 | 
			
		||||
	s.running = false
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Txt2ImgReq 文生图请求实体
 | 
			
		||||
@@ -163,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
	}
 | 
			
		||||
	var res Txt2ImgResp
 | 
			
		||||
	var errChan = make(chan error)
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
 | 
			
		||||
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("send image request to %s", apiURL)
 | 
			
		||||
	// send a request to sd api endpoint
 | 
			
		||||
	go func() {
 | 
			
		||||
		response, err := s.httpClient.R().
 | 
			
		||||
			SetHeader("Authorization", s.config.ApiKey).
 | 
			
		||||
			SetHeader("Authorization", apiKey.Value).
 | 
			
		||||
			SetBody(body).
 | 
			
		||||
			SetSuccessResult(&res).
 | 
			
		||||
			Post(apiURL)
 | 
			
		||||
@@ -181,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// update the last used time
 | 
			
		||||
		apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
		s.db.Updates(&apiKey)
 | 
			
		||||
 | 
			
		||||
		// 保存 Base64 图片
 | 
			
		||||
		imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -214,7 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
			_ = s.leveldb.Delete(task.Params.TaskId)
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			err, resp := s.checkTaskProgress()
 | 
			
		||||
			err, resp := s.checkTaskProgress(apiKey)
 | 
			
		||||
			// 更新任务进度
 | 
			
		||||
			if err == nil && resp.Progress > 0 {
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
			
		||||
@@ -232,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行任务
 | 
			
		||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
 | 
			
		||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
 | 
			
		||||
	var res TaskProgressResp
 | 
			
		||||
	response, err := s.httpClient.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.ApiKey).
 | 
			
		||||
		SetHeader("Authorization", apiKey.Value).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -248,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
 | 
			
		||||
 | 
			
		||||
	return nil, &res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.SdTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务标记为失败
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = "任务超时"
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user