refactor stable diffusion service, use api key instead of configs

This commit is contained in:
RockYang 2024-08-07 17:30:59 +08:00
parent 6a8b4ee2f1
commit 1d0006ce59
7 changed files with 139 additions and 262 deletions

View File

@ -17,15 +17,14 @@ type AppConfig struct {
Session Session Session Session
AdminSession Session AdminSession Session
ProxyURL string ProxyURL string
MysqlDns string // mysql 连接地址 MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录 StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息 Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
WeChatBot bool // 是否启用微信机器人 WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig XXLConfig XXLConfig
AlipayConfig AlipayConfig // 支付宝支付渠道配置 AlipayConfig AlipayConfig // 支付宝支付渠道配置
@ -51,27 +50,6 @@ type ApiConfig struct {
Token string Token string
} }
type MjProxyConfig struct {
Enabled bool
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type StableDiffusionConfig struct {
Enabled bool
Model string // 模型名称
ApiURL string
ApiKey string
}
type MjPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
ApiURL string // api 地址
Mode string // 绘画模式可选值fast/turbo/relax
ApiKey string
}
type AlipayConfig struct { type AlipayConfig struct {
Enabled bool // 是否启用该支付通道 Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境 SandBox bool // 是否沙盒环境

View File

@ -12,7 +12,6 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
@ -27,14 +26,12 @@ type ConfigHandler struct {
handler.BaseHandler handler.BaseHandler
levelDB *store.LevelDB levelDB *store.LevelDB
licenseService *service.LicenseService licenseService *service.LicenseService
sdServicePool *sd.ServicePool
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
return &ConfigHandler{ return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB, levelDB: levelDB,
sdServicePool: sdPool,
licenseService: licenseService, licenseService: licenseService,
} }
} }

View File

@ -32,15 +32,15 @@ import (
type SdJobHandler struct { type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
pool *sd.ServicePool service *sd.Service
uploader *oss.UploaderManager uploader *oss.UploaderManager
snowflake *service.Snowflake snowflake *service.Snowflake
leveldb *store.LevelDB leveldb *store.LevelDB
} }
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{ return &SdJobHandler{
pool: pool, service: service,
uploader: manager, uploader: manager,
snowflake: snowflake, snowflake: snowflake,
leveldb: levelDB, leveldb: levelDB,
@ -68,7 +68,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
} }
client := types.NewWsClient(ws) client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client) h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
} }
@ -79,11 +79,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.Power < h.App.SysConfig.SdPower { if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false return false
@ -164,14 +159,14 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.SdTask{ h.service.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
Type: types.TaskImage, Type: types.TaskImage,
Params: params, Params: params,
UserId: userId, UserId: userId,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) client := h.service.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
@ -328,11 +323,6 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte(service.TaskStatusFinished))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@ -199,13 +199,11 @@ func main() {
}), }),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewServicePool), fx.Provide(sd.NewService),
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
pool.InitServices(config.SdConfigs) s.Run()
if pool.HasAvailableService() { s.CheckTaskStatus()
pool.CheckTaskNotify() s.CheckTaskNotify()
pool.CheckTaskStatus()
}
}), }),
fx.Provide(suno.NewService), fx.Provide(suno.NewService),

View File

@ -50,12 +50,6 @@ type ImageRes struct {
Channel string `json:"channel,omitempty"` Channel string `json:"channel,omitempty"`
} }
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
type QueryRes struct { type QueryRes struct {
Action string `json:"action"` Action string `json:"action"`
Buttons []struct { Buttons []struct {
@ -193,7 +187,6 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) { func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
var res ImageRes var res ImageRes
var errRes ErrRes
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true) session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
if channel != "" { if channel != "" {
session = session.Where("api_url", channel) session = session.Where("api_url", channel)
@ -215,20 +208,14 @@ func (c *Client) doRequest(body interface{}, apiPath string, channel string) (Im
SetHeader("Authorization", "Bearer "+apiKey.Value). SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(body). SetBody(body).
SetSuccessResult(&res). SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL) Post(apiURL)
if err != nil { if err != nil {
errMsg := err.Error() return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
if r != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Error("请求 API 出错:", string(errStr))
errMsg = errMsg + " " + string(errStr)
}
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
} }
if r.IsErrorState() { if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) errMsg, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
} }
// update the api key last used time // update the api key last used time

View File

@ -1,128 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploader *oss.UploaderManager
levelDB *store.LevelDB
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
uploader: manager,
levelDB: levelDB,
}
}
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range configs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf(" sd-service-%d", k)
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
// run sd service
go func() {
service.Run()
}()
p.services = append(p.services, service)
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := p.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (p *ServicePool) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = 101
job.ErrMsg = "任务超时"
p.db.Updates(&job)
}
}
time.Sleep(time.Second * 5)
}
}()
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@ -16,7 +16,7 @@ import (
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"strings" "github.com/go-redis/redis/v8"
"time" "time"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
@ -29,79 +29,72 @@ var logger = logger2.GetLogger()
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
config types.StableDiffusionConfig
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
name string // service name
leveldb *store.LevelDB leveldb *store.LevelDB
running bool // 运行状态 Clients *types.LMap[uint, *types.WsClient] // UserId => Client
} }
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
return &Service{ return &Service{
name: name,
config: config,
httpClient: req.C(), httpClient: req.C(),
taskQueue: taskQueue, taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: notifyQueue, notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB, leveldb: levelDB,
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager, uploadManager: manager,
running: true,
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) logger.Infof("Starting Stable-Diffusion job consumer")
for s.running { go func() {
var task types.SdTask for {
err := s.taskQueue.LPop(&task) var task types.SdTask
if err != nil { err := s.taskQueue.LPop(&task)
logger.Errorf("taking task with error: %v", err) if err != nil {
continue logger.Errorf("taking task with error: %v", err)
} continue
}
// translate prompt // translate prompt
if utils.HasChinese(task.Params.Prompt) { if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
if err == nil { if err == nil {
task.Params.Prompt = content task.Params.Prompt = content
} else { } else {
logger.Warnf("error with translate prompt: %v", err) logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
} }
} }
}()
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
}
}
}
func (s *Service) Stop() {
s.running = false
} }
// Txt2ImgReq 文生图请求实体 // Txt2ImgReq 文生图请求实体
@ -163,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
var res Txt2ImgResp var res Txt2ImgResp
var errChan = make(chan error) var errChan = make(chan error)
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
var apiKey model.ApiKey
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
}
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
logger.Debugf("send image request to %s", apiURL) logger.Debugf("send image request to %s", apiURL)
// send a request to sd api endpoint // send a request to sd api endpoint
go func() { go func() {
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetBody(body). SetBody(body).
SetSuccessResult(&res). SetSuccessResult(&res).
Post(apiURL) Post(apiURL)
@ -181,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return return
} }
// update the last used time
apiKey.LastUsedAt = time.Now().Unix()
s.db.Updates(&apiKey)
// 保存 Base64 图片 // 保存 Base64 图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil { if err != nil {
@ -214,7 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
_ = s.leveldb.Delete(task.Params.TaskId) _ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
default: default:
err, resp := s.checkTaskProgress() err, resp := s.checkTaskProgress(apiKey)
// 更新任务进度 // 更新任务进度
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
@ -232,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
} }
// 执行任务 // 执行任务
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
var res TaskProgressResp var res TaskProgressResp
response, err := s.httpClient.R(). response, err := s.httpClient.R().
SetHeader("Authorization", s.config.ApiKey). SetHeader("Authorization", apiKey.Value).
SetSuccessResult(&res). SetSuccessResult(&res).
Get(apiURL) Get(apiURL)
if err != nil { if err != nil {
@ -248,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
return nil, &res return nil, &res
} }
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
time.Sleep(time.Second * 5)
}
}()
}