mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
refactor stable diffusion service, use api key instead of configs
This commit is contained in:
parent
6a8b4ee2f1
commit
1d0006ce59
@ -17,15 +17,14 @@ type AppConfig struct {
|
|||||||
Session Session
|
Session Session
|
||||||
AdminSession Session
|
AdminSession Session
|
||||||
ProxyURL string
|
ProxyURL string
|
||||||
MysqlDns string // mysql 连接地址
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||||
SMS SMSConfig // send mobile message config
|
SMS SMSConfig // send mobile message config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
WeChatBot bool // 是否启用微信机器人
|
WeChatBot bool // 是否启用微信机器人
|
||||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
|
||||||
|
|
||||||
XXLConfig XXLConfig
|
XXLConfig XXLConfig
|
||||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||||
@ -51,27 +50,6 @@ type ApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MjProxyConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
ApiURL string // api 地址
|
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type StableDiffusionConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
Model string // 模型名称
|
|
||||||
ApiURL string
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type MjPlusConfig struct {
|
|
||||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
|
||||||
ApiURL string // api 地址
|
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
|
||||||
ApiKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type AlipayConfig struct {
|
type AlipayConfig struct {
|
||||||
Enabled bool // 是否启用该支付通道
|
Enabled bool // 是否启用该支付通道
|
||||||
SandBox bool // 是否沙盒环境
|
SandBox bool // 是否沙盒环境
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
"geekai/handler"
|
"geekai/handler"
|
||||||
"geekai/service"
|
"geekai/service"
|
||||||
"geekai/service/sd"
|
|
||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
@ -27,14 +26,12 @@ type ConfigHandler struct {
|
|||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
levelDB *store.LevelDB
|
levelDB *store.LevelDB
|
||||||
licenseService *service.LicenseService
|
licenseService *service.LicenseService
|
||||||
sdServicePool *sd.ServicePool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler {
|
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
||||||
return &ConfigHandler{
|
return &ConfigHandler{
|
||||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||||
levelDB: levelDB,
|
levelDB: levelDB,
|
||||||
sdServicePool: sdPool,
|
|
||||||
licenseService: licenseService,
|
licenseService: licenseService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,15 +32,15 @@ import (
|
|||||||
type SdJobHandler struct {
|
type SdJobHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
pool *sd.ServicePool
|
service *sd.Service
|
||||||
uploader *oss.UploaderManager
|
uploader *oss.UploaderManager
|
||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
||||||
return &SdJobHandler{
|
return &SdJobHandler{
|
||||||
pool: pool,
|
service: service,
|
||||||
uploader: manager,
|
uploader: manager,
|
||||||
snowflake: snowflake,
|
snowflake: snowflake,
|
||||||
leveldb: levelDB,
|
leveldb: levelDB,
|
||||||
@ -68,7 +68,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
h.pool.Clients.Put(uint(userId), client)
|
h.service.Clients.Put(uint(userId), client)
|
||||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,11 +79,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.pool.HasAvailableService() {
|
|
||||||
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Power < h.App.SysConfig.SdPower {
|
if user.Power < h.App.SysConfig.SdPower {
|
||||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||||
return false
|
return false
|
||||||
@ -164,14 +159,14 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.pool.PushTask(types.SdTask{
|
h.service.PushTask(types.SdTask{
|
||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
Type: types.TaskImage,
|
Type: types.TaskImage,
|
||||||
Params: params,
|
Params: params,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
})
|
})
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
client := h.service.Clients.Get(uint(job.UserId))
|
||||||
if client != nil {
|
if client != nil {
|
||||||
_ = client.Send([]byte("Task Updated"))
|
_ = client.Send([]byte("Task Updated"))
|
||||||
}
|
}
|
||||||
@ -328,11 +323,6 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
|||||||
logger.Error("remove image failed: ", err)
|
logger.Error("remove image failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := h.pool.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte(service.TaskStatusFinished))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
12
api/main.go
12
api/main.go
@ -199,13 +199,11 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewServicePool),
|
fx.Provide(sd.NewService),
|
||||||
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
|
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
|
||||||
pool.InitServices(config.SdConfigs)
|
s.Run()
|
||||||
if pool.HasAvailableService() {
|
s.CheckTaskStatus()
|
||||||
pool.CheckTaskNotify()
|
s.CheckTaskNotify()
|
||||||
pool.CheckTaskStatus()
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
fx.Provide(suno.NewService),
|
fx.Provide(suno.NewService),
|
||||||
|
@ -50,12 +50,6 @@ type ImageRes struct {
|
|||||||
Channel string `json:"channel,omitempty"`
|
Channel string `json:"channel,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrRes struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryRes struct {
|
type QueryRes struct {
|
||||||
Action string `json:"action"`
|
Action string `json:"action"`
|
||||||
Buttons []struct {
|
Buttons []struct {
|
||||||
@ -193,7 +187,6 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
|||||||
|
|
||||||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
|
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
var errRes ErrRes
|
|
||||||
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
|
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
|
||||||
if channel != "" {
|
if channel != "" {
|
||||||
session = session.Where("api_url", channel)
|
session = session.Where("api_url", channel)
|
||||||
@ -215,20 +208,14 @@ func (c *Client) doRequest(body interface{}, apiPath string, channel string) (Im
|
|||||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
SetBody(body).
|
SetBody(body).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
SetErrorResult(&errRes).
|
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := err.Error()
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
if r != nil {
|
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
|
||||||
logger.Error("请求 API 出错:", string(errStr))
|
|
||||||
errMsg = errMsg + " " + string(errStr)
|
|
||||||
}
|
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
errMsg, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
|
||||||
}
|
}
|
||||||
|
|
||||||
// update the api key last used time
|
// update the api key last used time
|
||||||
|
@ -1,128 +0,0 @@
|
|||||||
package sd
|
|
||||||
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
|
||||||
// * Use of this source code is governed by a Apache-2.0 license
|
|
||||||
// * that can be found in the LICENSE file.
|
|
||||||
// * @Author yangjian102621@163.com
|
|
||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"geekai/core/types"
|
|
||||||
"geekai/service"
|
|
||||||
"geekai/service/oss"
|
|
||||||
"geekai/store"
|
|
||||||
"geekai/store/model"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServicePool struct {
|
|
||||||
services []*Service
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
notifyQueue *store.RedisQueue
|
|
||||||
db *gorm.DB
|
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
|
||||||
uploader *oss.UploaderManager
|
|
||||||
levelDB *store.LevelDB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
|
|
||||||
services := make([]*Service, 0)
|
|
||||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
|
||||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
|
||||||
|
|
||||||
return &ServicePool{
|
|
||||||
taskQueue: taskQueue,
|
|
||||||
notifyQueue: notifyQueue,
|
|
||||||
services: services,
|
|
||||||
db: db,
|
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
|
||||||
uploader: manager,
|
|
||||||
levelDB: levelDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
|
|
||||||
// stop old service
|
|
||||||
for _, s := range p.services {
|
|
||||||
s.Stop()
|
|
||||||
}
|
|
||||||
p.services = make([]*Service, 0)
|
|
||||||
|
|
||||||
for k, config := range configs {
|
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// create sd service
|
|
||||||
name := fmt.Sprintf(" sd-service-%d", k)
|
|
||||||
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
|
|
||||||
// run sd service
|
|
||||||
go func() {
|
|
||||||
service.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
p.services = append(p.services, service)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushTask push a new mj task in to task queue
|
|
||||||
func (p *ServicePool) PushTask(task types.SdTask) {
|
|
||||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
|
||||||
p.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ServicePool) CheckTaskNotify() {
|
|
||||||
go func() {
|
|
||||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
|
||||||
for {
|
|
||||||
var message service.NotifyMessage
|
|
||||||
err := p.notifyQueue.LPop(&message)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
client := p.Clients.Get(uint(message.UserId))
|
|
||||||
if client == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = client.Send([]byte(message.Message))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
|
||||||
func (p *ServicePool) CheckTaskStatus() {
|
|
||||||
go func() {
|
|
||||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
|
||||||
for {
|
|
||||||
var jobs []model.SdJob
|
|
||||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
|
||||||
if res.Error != nil {
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, job := range jobs {
|
|
||||||
// 5 分钟还没完成的任务标记为失败
|
|
||||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
|
||||||
job.Progress = 101
|
|
||||||
job.ErrMsg = "任务超时"
|
|
||||||
p.db.Updates(&job)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 5)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAvailableService check if it has available mj service in pool
|
|
||||||
func (p *ServicePool) HasAvailableService() bool {
|
|
||||||
return len(p.services) > 0
|
|
||||||
}
|
|
@ -16,7 +16,7 @@ import (
|
|||||||
"geekai/store"
|
"geekai/store"
|
||||||
"geekai/store/model"
|
"geekai/store/model"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"strings"
|
"github.com/go-redis/redis/v8"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
@ -29,79 +29,72 @@ var logger = logger2.GetLogger()
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
httpClient *req.Client
|
httpClient *req.Client
|
||||||
config types.StableDiffusionConfig
|
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
name string // service name
|
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
running bool // 运行状态
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
||||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
|
||||||
config: config,
|
|
||||||
httpClient: req.C(),
|
httpClient: req.C(),
|
||||||
taskQueue: taskQueue,
|
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||||
db: db,
|
db: db,
|
||||||
leveldb: levelDB,
|
leveldb: levelDB,
|
||||||
|
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
running: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
|
logger.Infof("Starting Stable-Diffusion job consumer")
|
||||||
for s.running {
|
go func() {
|
||||||
var task types.SdTask
|
for {
|
||||||
err := s.taskQueue.LPop(&task)
|
var task types.SdTask
|
||||||
if err != nil {
|
err := s.taskQueue.LPop(&task)
|
||||||
logger.Errorf("taking task with error: %v", err)
|
if err != nil {
|
||||||
continue
|
logger.Errorf("taking task with error: %v", err)
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// translate prompt
|
// translate prompt
|
||||||
if utils.HasChinese(task.Params.Prompt) {
|
if utils.HasChinese(task.Params.Prompt) {
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
task.Params.Prompt = content
|
task.Params.Prompt = content
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// translate negative prompt
|
||||||
|
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||||
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||||
|
if err == nil {
|
||||||
|
task.Params.NegPrompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
|
||||||
|
err = s.Txt2Img(task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("绘画任务执行失败:", err.Error())
|
||||||
|
// update the task progress
|
||||||
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": service.FailTaskProgress,
|
||||||
|
"err_msg": err.Error(),
|
||||||
|
})
|
||||||
|
// 通知前端,任务失败
|
||||||
|
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
// translate negative prompt
|
|
||||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
|
||||||
if err == nil {
|
|
||||||
task.Params.NegPrompt = content
|
|
||||||
} else {
|
|
||||||
logger.Warnf("error with translate prompt: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
|
||||||
err = s.Txt2Img(task)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("绘画任务执行失败:", err.Error())
|
|
||||||
// update the task progress
|
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": service.FailTaskProgress,
|
|
||||||
"err_msg": err.Error(),
|
|
||||||
})
|
|
||||||
// 通知前端,任务失败
|
|
||||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Stop() {
|
|
||||||
s.running = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Txt2ImgReq 文生图请求实体
|
// Txt2ImgReq 文生图请求实体
|
||||||
@ -163,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
}
|
}
|
||||||
var res Txt2ImgResp
|
var res Txt2ImgResp
|
||||||
var errChan = make(chan error)
|
var errChan = make(chan error)
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
|
||||||
logger.Debugf("send image request to %s", apiURL)
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
// send a request to sd api endpoint
|
// send a request to sd api endpoint
|
||||||
go func() {
|
go func() {
|
||||||
response, err := s.httpClient.R().
|
response, err := s.httpClient.R().
|
||||||
SetHeader("Authorization", s.config.ApiKey).
|
SetHeader("Authorization", apiKey.Value).
|
||||||
SetBody(body).
|
SetBody(body).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
@ -181,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update the last used time
|
||||||
|
apiKey.LastUsedAt = time.Now().Unix()
|
||||||
|
s.db.Updates(&apiKey)
|
||||||
|
|
||||||
// 保存 Base64 图片
|
// 保存 Base64 图片
|
||||||
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -214,7 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
err, resp := s.checkTaskProgress()
|
err, resp := s.checkTaskProgress(apiKey)
|
||||||
// 更新任务进度
|
// 更新任务进度
|
||||||
if err == nil && resp.Progress > 0 {
|
if err == nil && resp.Progress > 0 {
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
@ -232,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 执行任务
|
// 执行任务
|
||||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
||||||
var res TaskProgressResp
|
var res TaskProgressResp
|
||||||
response, err := s.httpClient.R().
|
response, err := s.httpClient.R().
|
||||||
SetHeader("Authorization", s.config.ApiKey).
|
SetHeader("Authorization", apiKey.Value).
|
||||||
SetSuccessResult(&res).
|
SetSuccessResult(&res).
|
||||||
Get(apiURL)
|
Get(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -248,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
|||||||
|
|
||||||
return nil, &res
|
return nil, &res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) PushTask(task types.SdTask) {
|
||||||
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CheckTaskNotify() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||||
|
for {
|
||||||
|
var message service.NotifyMessage
|
||||||
|
err := s.notifyQueue.LPop(&message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
client := s.Clients.Get(uint(message.UserId))
|
||||||
|
if client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = client.Send([]byte(message.Message))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||||
|
func (s *Service) CheckTaskStatus() {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||||
|
for {
|
||||||
|
var jobs []model.SdJob
|
||||||
|
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||||
|
if res.Error != nil {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, job := range jobs {
|
||||||
|
// 5 分钟还没完成的任务标记为失败
|
||||||
|
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||||
|
job.Progress = service.FailTaskProgress
|
||||||
|
job.ErrMsg = "任务超时"
|
||||||
|
s.db.Updates(&job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 5)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user