From 1d0006ce59881d3f47c7631c512b1a335804473d Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 7 Aug 2024 17:30:59 +0800 Subject: [PATCH] refactor stable diffusion service, use api key instead of configs --- api/core/types/config.go | 38 ++---- api/handler/admin/config_handler.go | 5 +- api/handler/sd_handler.go | 22 +--- api/main.go | 12 +- api/service/mj/client.go | 19 +-- api/service/sd/pool.go | 128 -------------------- api/service/sd/service.go | 177 ++++++++++++++++++---------- 7 files changed, 139 insertions(+), 262 deletions(-) delete mode 100644 api/service/sd/pool.go diff --git a/api/core/types/config.go b/api/core/types/config.go index 984a6982..a9d8ea4c 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -17,15 +17,14 @@ type AppConfig struct { Session Session AdminSession Session ProxyURL string - MysqlDns string // mysql 连接地址 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ApiConfig // ChatPlus API authorization configs - SMS SMSConfig // send mobile message config - OSS OSSConfig // OSS config - WeChatBot bool // 是否启用微信机器人 - SdConfigs []StableDiffusionConfig // sd AI draw service pool + MysqlDns string // mysql 连接地址 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ApiConfig // ChatPlus API authorization configs + SMS SMSConfig // send mobile message config + OSS OSSConfig // OSS config + WeChatBot bool // 是否启用微信机器人 XXLConfig XXLConfig AlipayConfig AlipayConfig // 支付宝支付渠道配置 @@ -51,27 +50,6 @@ type ApiConfig struct { Token string } -type MjProxyConfig struct { - Enabled bool - ApiURL string // api 地址 - Mode string // 绘画模式,可选值:fast/turbo/relax - ApiKey string -} - -type StableDiffusionConfig struct { - Enabled bool - Model string // 模型名称 - ApiURL string - ApiKey string -} - -type MjPlusConfig struct { - Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 - ApiURL string // api 地址 - Mode string // 绘画模式,可选值:fast/turbo/relax - ApiKey string -} - type AlipayConfig struct { Enabled bool // 是否启用该支付通道 SandBox bool // 是否沙盒环境 diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index b3d22705..4a6aa690 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -12,7 +12,6 @@ import ( "geekai/core/types" "geekai/handler" "geekai/service" - "geekai/service/sd" "geekai/store" "geekai/store/model" "geekai/utils" @@ -27,14 +26,12 @@ type ConfigHandler struct { handler.BaseHandler levelDB *store.LevelDB licenseService *service.LicenseService - sdServicePool *sd.ServicePool } -func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler { +func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler { return &ConfigHandler{ BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, - sdServicePool: sdPool, licenseService: licenseService, } } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index ff5320ef..a9ff01c5 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -32,15 +32,15 @@ import ( type SdJobHandler struct { BaseHandler redis *redis.Client - pool *sd.ServicePool + service *sd.Service uploader *oss.UploaderManager snowflake *service.Snowflake leveldb *store.LevelDB } -func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { +func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { return &SdJobHandler{ - pool: pool, + service: service, uploader: manager, snowflake: snowflake, leveldb: levelDB, @@ -68,7 +68,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.pool.Clients.Put(uint(userId), client) + h.service.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -79,11 +79,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool { return false } - if !h.pool.HasAvailableService() { - resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!") - return false - } - if user.Power < h.App.SysConfig.SdPower { resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") return false @@ -164,14 +159,14 @@ func (h *SdJobHandler) Image(c *gin.Context) { return } - h.pool.PushTask(types.SdTask{ + h.service.PushTask(types.SdTask{ Id: int(job.Id), Type: types.TaskImage, Params: params, UserId: userId, }) - client := h.pool.Clients.Get(uint(job.UserId)) + client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -328,11 +323,6 @@ func (h *SdJobHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } - client := h.pool.Clients.Get(uint(job.UserId)) - if client != nil { - _ = client.Send([]byte(service.TaskStatusFinished)) - } - resp.SUCCESS(c) } diff --git a/api/main.go b/api/main.go index 287c6c07..26ca6fed 100644 --- a/api/main.go +++ b/api/main.go @@ -199,13 +199,11 @@ func main() { }), // Stable Diffusion 机器人 - fx.Provide(sd.NewServicePool), - fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { - pool.InitServices(config.SdConfigs) - if pool.HasAvailableService() { - pool.CheckTaskNotify() - pool.CheckTaskStatus() - } + fx.Provide(sd.NewService), + fx.Invoke(func(s *sd.Service, config *types.AppConfig) { + s.Run() + s.CheckTaskStatus() + s.CheckTaskNotify() }), fx.Provide(suno.NewService), diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 450b7d8b..35ed807a 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -50,12 +50,6 @@ type ImageRes struct { Channel string `json:"channel,omitempty"` } -type ErrRes struct { - Error struct { - Message string `json:"message"` - } `json:"error"` -} - type QueryRes struct { Action string `json:"action"` Buttons []struct { @@ -193,7 +187,6 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) { func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) { var res ImageRes - var errRes ErrRes session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true) if channel != "" { session = session.Where("api_url", channel) @@ -215,20 +208,14 @@ func (c *Client) doRequest(body interface{}, apiPath string, channel string) (Im SetHeader("Authorization", "Bearer "+apiKey.Value). SetBody(body). SetSuccessResult(&res). - SetErrorResult(&errRes). Post(apiURL) if err != nil { - errMsg := err.Error() - if r != nil { - errStr, _ := io.ReadAll(r.Body) - logger.Error("请求 API 出错:", string(errStr)) - errMsg = errMsg + " " + string(errStr) - } - return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg) + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) } if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + errMsg, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg)) } // update the api key last used time diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go deleted file mode 100644 index d0033f67..00000000 --- a/api/service/sd/pool.go +++ /dev/null @@ -1,128 +0,0 @@ -package sd - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "fmt" - "geekai/core/types" - "geekai/service" - "geekai/service/oss" - "geekai/store" - "geekai/store/model" - "time" - - "github.com/go-redis/redis/v8" - "gorm.io/gorm" -) - -type ServicePool struct { - services []*Service - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - Clients *types.LMap[uint, *types.WsClient] // UserId => Client - uploader *oss.UploaderManager - levelDB *store.LevelDB -} - -func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool { - services := make([]*Service, 0) - taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) - notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) - - return &ServicePool{ - taskQueue: taskQueue, - notifyQueue: notifyQueue, - services: services, - db: db, - Clients: types.NewLMap[uint, *types.WsClient](), - uploader: manager, - levelDB: levelDB, - } -} - -func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) { - // stop old service - for _, s := range p.services { - s.Stop() - } - p.services = make([]*Service, 0) - - for k, config := range configs { - if config.Enabled == false { - continue - } - - // create sd service - name := fmt.Sprintf(" sd-service-%d", k) - service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB) - // run sd service - go func() { - service.Run() - }() - - p.services = append(p.services, service) - } -} - -// PushTask push a new mj task in to task queue -func (p *ServicePool) PushTask(task types.SdTask) { - logger.Debugf("add a new MidJourney task to the task list: %+v", task) - p.taskQueue.RPush(task) -} - -func (p *ServicePool) CheckTaskNotify() { - go func() { - logger.Info("Running Stable-Diffusion task notify checking ...") - for { - var message service.NotifyMessage - err := p.notifyQueue.LPop(&message) - if err != nil { - continue - } - client := p.Clients.Get(uint(message.UserId)) - if client == nil { - continue - } - err = client.Send([]byte(message.Message)) - if err != nil { - continue - } - } - }() -} - -// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 -func (p *ServicePool) CheckTaskStatus() { - go func() { - logger.Info("Running Stable-Diffusion task status checking ...") - for { - var jobs []model.SdJob - res := p.db.Where("progress < ?", 100).Find(&jobs) - if res.Error != nil { - time.Sleep(5 * time.Second) - continue - } - - for _, job := range jobs { - // 5 分钟还没完成的任务标记为失败 - if time.Now().Sub(job.CreatedAt) > time.Minute*5 { - job.Progress = 101 - job.ErrMsg = "任务超时" - p.db.Updates(&job) - } - } - time.Sleep(time.Second * 5) - } - }() -} - -// HasAvailableService check if it has available mj service in pool -func (p *ServicePool) HasAvailableService() bool { - return len(p.services) > 0 -} diff --git a/api/service/sd/service.go b/api/service/sd/service.go index d3b6c231..a6c9a856 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -16,7 +16,7 @@ import ( "geekai/store" "geekai/store/model" "geekai/utils" - "strings" + "github.com/go-redis/redis/v8" "time" "github.com/imroc/req/v3" @@ -29,79 +29,72 @@ var logger = logger2.GetLogger() type Service struct { httpClient *req.Client - config types.StableDiffusionConfig taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB uploadManager *oss.UploaderManager - name string // service name leveldb *store.LevelDB - running bool // 运行状态 + Clients *types.LMap[uint, *types.WsClient] // UserId => Client } -func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { - config.ApiURL = strings.TrimRight(config.ApiURL, "/") +func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service { return &Service{ - name: name, - config: config, httpClient: req.C(), - taskQueue: taskQueue, - notifyQueue: notifyQueue, + taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), db: db, leveldb: levelDB, + Clients: types.NewLMap[uint, *types.WsClient](), uploadManager: manager, - running: true, } } func (s *Service) Run() { - logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) - for s.running { - var task types.SdTask - err := s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } + logger.Infof("Starting Stable-Diffusion job consumer") + go func() { + for { + var task types.SdTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } - // translate prompt - if utils.HasChinese(task.Params.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") - if err == nil { - task.Params.Prompt = content - } else { - logger.Warnf("error with translate prompt: %v", err) + // translate prompt + if utils.HasChinese(task.Params.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini") + if err == nil { + task.Params.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + + // translate negative prompt + if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini") + if err == nil { + task.Params.NegPrompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + + logger.Infof("handle a new Stable-Diffusion task: %+v", task) + err = s.Txt2Img(task) + if err != nil { + logger.Error("绘画任务执行失败:", err.Error()) + // update the task progress + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ + "progress": service.FailTaskProgress, + "err_msg": err.Error(), + }) + // 通知前端,任务失败 + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) + continue } } - - // translate negative prompt - if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini") - if err == nil { - task.Params.NegPrompt = content - } else { - logger.Warnf("error with translate prompt: %v", err) - } - } - - logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) - err = s.Txt2Img(task) - if err != nil { - logger.Error("绘画任务执行失败:", err.Error()) - // update the task progress - s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ - "progress": service.FailTaskProgress, - "err_msg": err.Error(), - }) - // 通知前端,任务失败 - s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) - continue - } - } -} - -func (s *Service) Stop() { - s.running = false + }() } // Txt2ImgReq 文生图请求实体 @@ -163,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error { } var res Txt2ImgResp var errChan = make(chan error) - apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) + + var apiKey model.ApiKey + err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error + if err != nil { + return fmt.Errorf("no available Stable-Diffusion api key: %v", err) + } + + apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL) logger.Debugf("send image request to %s", apiURL) // send a request to sd api endpoint go func() { response, err := s.httpClient.R(). - SetHeader("Authorization", s.config.ApiKey). + SetHeader("Authorization", apiKey.Value). SetBody(body). SetSuccessResult(&res). Post(apiURL) @@ -181,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error { return } + // update the last used time + apiKey.LastUsedAt = time.Now().Unix() + s.db.Updates(&apiKey) + // 保存 Base64 图片 imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) if err != nil { @@ -214,7 +218,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { _ = s.leveldb.Delete(task.Params.TaskId) return nil default: - err, resp := s.checkTaskProgress() + err, resp := s.checkTaskProgress(apiKey) // 更新任务进度 if err == nil && resp.Progress > 0 { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) @@ -232,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error { } // 执行任务 -func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { - apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) +func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) { + apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL) var res TaskProgressResp response, err := s.httpClient.R(). - SetHeader("Authorization", s.config.ApiKey). + SetHeader("Authorization", apiKey.Value). SetSuccessResult(&res). Get(apiURL) if err != nil { @@ -248,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { return nil, &res } + +func (s *Service) PushTask(task types.SdTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running Stable-Diffusion task notify checking ...") + for { + var message service.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 +func (s *Service) CheckTaskStatus() { + go func() { + logger.Info("Running Stable-Diffusion task status checking ...") + for { + var jobs []model.SdJob + res := s.db.Where("progress < ?", 100).Find(&jobs) + if res.Error != nil { + time.Sleep(5 * time.Second) + continue + } + + for _, job := range jobs { + // 5 分钟还没完成的任务标记为失败 + if time.Now().Sub(job.CreatedAt) > time.Minute*5 { + job.Progress = service.FailTaskProgress + job.ErrMsg = "任务超时" + s.db.Updates(&job) + } + } + time.Sleep(time.Second * 5) + } + }() +}