mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
refactor stable diffusion service, use api key instead of configs
This commit is contained in:
parent
6a8b4ee2f1
commit
1d0006ce59
@ -25,7 +25,6 @@ type AppConfig struct {
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
WeChatBot bool // 是否启用微信机器人
|
||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
||||
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
@ -51,27 +50,6 @@ type ApiConfig struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
type MjProxyConfig struct {
|
||||
Enabled bool
|
||||
ApiURL string // api 地址
|
||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type StableDiffusionConfig struct {
|
||||
Enabled bool
|
||||
Model string // 模型名称
|
||||
ApiURL string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type MjPlusConfig struct {
|
||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
||||
ApiURL string // api 地址
|
||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type AlipayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
SandBox bool // 是否沙盒环境
|
||||
|
@ -12,7 +12,6 @@ import (
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
@ -27,14 +26,12 @@ type ConfigHandler struct {
|
||||
handler.BaseHandler
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
sdServicePool *sd.ServicePool
|
||||
}
|
||||
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler {
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
levelDB: levelDB,
|
||||
sdServicePool: sdPool,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
@ -32,15 +32,15 @@ import (
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
pool *sd.ServicePool
|
||||
service *sd.Service
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
}
|
||||
|
||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
|
||||
return &SdJobHandler{
|
||||
pool: pool,
|
||||
service: service,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
@ -68,7 +68,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.pool.Clients.Put(uint(userId), client)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
@ -79,11 +79,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.pool.HasAvailableService() {
|
||||
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.SdPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
@ -164,14 +159,14 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.SdTask{
|
||||
h.service.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@ -328,11 +323,6 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte(service.TaskStatusFinished))
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
12
api/main.go
12
api/main.go
@ -199,13 +199,11 @@ func main() {
|
||||
}),
|
||||
|
||||
// Stable Diffusion 机器人
|
||||
fx.Provide(sd.NewServicePool),
|
||||
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
|
||||
pool.InitServices(config.SdConfigs)
|
||||
if pool.HasAvailableService() {
|
||||
pool.CheckTaskNotify()
|
||||
pool.CheckTaskStatus()
|
||||
}
|
||||
fx.Provide(sd.NewService),
|
||||
fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
|
||||
s.Run()
|
||||
s.CheckTaskStatus()
|
||||
s.CheckTaskNotify()
|
||||
}),
|
||||
|
||||
fx.Provide(suno.NewService),
|
||||
|
@ -50,12 +50,6 @@ type ImageRes struct {
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
type ErrRes struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
type QueryRes struct {
|
||||
Action string `json:"action"`
|
||||
Buttons []struct {
|
||||
@ -193,7 +187,6 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
||||
|
||||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
|
||||
if channel != "" {
|
||||
session = session.Where("api_url", channel)
|
||||
@ -215,20 +208,14 @@ func (c *Client) doRequest(body interface{}, apiPath string, channel string) (Im
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if r != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
logger.Error("请求 API 出错:", string(errStr))
|
||||
errMsg = errMsg + " " + string(errStr)
|
||||
}
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
errMsg, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
|
||||
}
|
||||
|
||||
// update the api key last used time
|
||||
|
@ -1,128 +0,0 @@
|
||||
package sd
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ServicePool struct {
|
||||
services []*Service
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
uploader *oss.UploaderManager
|
||||
levelDB *store.LevelDB
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
services: services,
|
||||
db: db,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploader: manager,
|
||||
levelDB: levelDB,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
|
||||
// stop old service
|
||||
for _, s := range p.services {
|
||||
s.Stop()
|
||||
}
|
||||
p.services = make([]*Service, 0)
|
||||
|
||||
for k, config := range configs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
|
||||
// create sd service
|
||||
name := fmt.Sprintf(" sd-service-%d", k)
|
||||
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
|
||||
// run sd service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
p.services = append(p.services, service)
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (p *ServicePool) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
p.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (p *ServicePool) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := p.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := p.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
func (p *ServicePool) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||
for {
|
||||
var jobs []model.SdJob
|
||||
res := p.db.Where("progress < ?", 100).Find(&jobs)
|
||||
if res.Error != nil {
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 5 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||
job.Progress = 101
|
||||
job.ErrMsg = "任务超时"
|
||||
p.db.Updates(&job)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// HasAvailableService check if it has available mj service in pool
|
||||
func (p *ServicePool) HasAvailableService() bool {
|
||||
return len(p.services) > 0
|
||||
}
|
@ -16,7 +16,7 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
@ -29,34 +29,30 @@ var logger = logger2.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
config types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
name string // service name
|
||||
leveldb *store.LevelDB
|
||||
running bool // 运行状态
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
|
||||
config.ApiURL = strings.TrimRight(config.ApiURL, "/")
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
||||
return &Service{
|
||||
name: name,
|
||||
config: config,
|
||||
httpClient: req.C(),
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
running: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
|
||||
for s.running {
|
||||
logger.Infof("Starting Stable-Diffusion job consumer")
|
||||
go func() {
|
||||
for {
|
||||
var task types.SdTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
@ -84,7 +80,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
|
||||
logger.Infof("handle a new Stable-Diffusion task: %+v", task)
|
||||
err = s.Txt2Img(task)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err.Error())
|
||||
@ -98,10 +94,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Stop() {
|
||||
s.running = false
|
||||
}()
|
||||
}
|
||||
|
||||
// Txt2ImgReq 文生图请求实体
|
||||
@ -163,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
}
|
||||
var res Txt2ImgResp
|
||||
var errChan = make(chan error)
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
|
||||
var apiKey model.ApiKey
|
||||
err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
// send a request to sd api endpoint
|
||||
go func() {
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetHeader("Authorization", apiKey.Value).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
@ -181,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
return
|
||||
}
|
||||
|
||||
// update the last used time
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
s.db.Updates(&apiKey)
|
||||
|
||||
// 保存 Base64 图片
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
|
||||
if err != nil {
|
||||
@ -214,7 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress()
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
// 更新任务进度
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
@ -232,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetHeader("Authorization", apiKey.Value).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
@ -248,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
|
||||
return nil, &res
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
func (s *Service) CheckTaskStatus() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task status checking ...")
|
||||
for {
|
||||
var jobs []model.SdJob
|
||||
res := s.db.Where("progress < ?", 100).Find(&jobs)
|
||||
if res.Error != nil {
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 5 分钟还没完成的任务标记为失败
|
||||
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
|
||||
job.Progress = service.FailTaskProgress
|
||||
job.ErrMsg = "任务超时"
|
||||
s.db.Updates(&job)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user