From e35c34ad9a056482a8d8a98cb2749d02bed26eb0 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sat, 11 May 2024 17:27:14 +0800 Subject: [PATCH] enable to update AI Drawing configuarations in admin console page --- CHANGELOG.md | 5 + api/core/types/task.go | 7 +- api/handler/admin/config_handler.go | 60 ++++++++++- api/handler/admin/user_handler.go | 4 +- api/handler/user_handler.go | 4 +- api/main.go | 7 +- api/service/mj/pool.go | 78 ++++++++------- api/service/mj/service.go | 9 +- api/service/sd/pool.go | 45 +++++---- api/service/sd/service.go | 13 ++- web/public/index.html | 2 +- web/src/router.js | 4 +- web/src/views/ImageSd.vue | 28 +++++- web/src/views/admin/AIDrawing.vue | 150 ++++++++++++++++++++++------ web/src/views/admin/Login.vue | 31 +++++- web/src/views/admin/SysConfig.vue | 4 +- 16 files changed, 343 insertions(+), 108 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cb19854..a046c9e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ * 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位 * Bug修复:修复思维导图不扣费的Bug * Bug修复:修复管理后台角色删除失败的Bug +* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数 +* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY ## v4.0.5 @@ -45,6 +47,7 @@ * 功能新增:支持管理后台 Logo 修改 ## 4.0.2 + * 功能新增:支持前端菜单可以配置 * 功能优化:在登录和注册界面标题显示软件版本号 * 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数 @@ -54,6 +57,7 @@ * 功能新增:管理后台登录页面增加行为验证码,防止爆破 ## v4.0.1 + * 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些 * 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 @@ -63,6 +67,7 @@ * Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug ## v4.0.0 + 非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。 只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力... diff --git a/api/core/types/task.go b/api/core/types/task.go index fbae9a3a..6b6a364c 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -55,9 +55,10 @@ type SdTaskParams struct { NegPrompt string `json:"neg_prompt"` // 反向提示词 Steps int `json:"steps"` // 迭代步数,默认20 Sampler string `json:"sampler"` // 采样器 - FaceFix bool `json:"face_fix"` // 面部修复 - CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7 - Seed int64 `json:"seed"` // 随机数种子 + Scheduler string `json:"scheduler"` + FaceFix bool `json:"face_fix"` // 面部修复 + CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7 + Seed int64 `json:"seed"` // 随机数种子 Height int `json:"height"` Width int `json:"width"` HdFix bool `json:"hd_fix"` // 启用高清修复 diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 4987e58e..6cf571e2 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -12,6 +12,8 @@ import ( "geekai/core/types" "geekai/handler" "geekai/service" + "geekai/service/mj" + "geekai/service/sd" "geekai/store" "geekai/store/model" "geekai/utils" @@ -26,10 +28,18 @@ type ConfigHandler struct { handler.BaseHandler levelDB *store.LevelDB licenseService *service.LicenseService + mjServicePool *mj.ServicePool + sdServicePool *sd.ServicePool } -func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler { - return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, licenseService: licenseService} +func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { + return &ConfigHandler{ + BaseHandler: handler.BaseHandler{App: app, DB: db}, + levelDB: levelDB, + mjServicePool: mjPool, + sdServicePool: sdPool, + licenseService: licenseService, + } } func (h *ConfigHandler) Update(c *gin.Context) { @@ -138,3 +148,49 @@ func (h *ConfigHandler) GetDrawingConfig(c *gin.Context) { "sd": h.App.Config.SdConfigs, }) } + +// SaveDrawingConfig 保存AI绘画配置 +func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) { + var data struct { + Sd []types.StableDiffusionConfig `json:"sd"` + MjPlus []types.MjPlusConfig `json:"mj_plus"` + MjProxy []types.MjProxyConfig `json:"mj_proxy"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + changed := false + if configChanged(data.Sd, h.App.Config.SdConfigs) { + logger.Debugf("SD 配置变动了") + h.App.Config.SdConfigs = data.Sd + h.sdServicePool.InitServices(data.Sd) + changed = true + } + + if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) { + logger.Debugf("MidJourney 配置变动了") + h.App.Config.MjPlusConfigs = data.MjPlus + h.App.Config.MjProxyConfigs = data.MjProxy + h.mjServicePool.InitServices(data.MjPlus, data.MjProxy) + changed = true + } + + if changed { + err := core.SaveConfig(h.App.Config) + if err != nil { + resp.ERROR(c, "更新配置文档失败!") + return + } + } + + resp.SUCCESS(c) + +} + +func configChanged(c1 interface{}, c2 interface{}) bool { + encode1 := utils.JsonEncode(c1) + encode2 := utils.JsonEncode(c2) + return utils.Md5(encode1) != utils.Md5(encode2) +} diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 4148edd8..965c06f3 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -8,6 +8,7 @@ package admin // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "fmt" "geekai/core" "geekai/core/types" "geekai/handler" @@ -16,7 +17,6 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "fmt" "time" "github.com/gin-gonic/gin" @@ -87,7 +87,7 @@ func (h *UserHandler) Save(c *gin.Context) { // 检测最大注册人数 var totalUser int64 h.DB.Model(&model.User{}).Count(&totalUser) - if int(totalUser) >= h.licenseService.GetLicense().UserNum { + if h.licenseService.GetLicense().UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().UserNum { resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") return } diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 6068f565..5df76df1 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -8,6 +8,7 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "fmt" "geekai/core" "geekai/core/types" "geekai/service" @@ -15,7 +16,6 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "fmt" "strings" "time" @@ -71,7 +71,7 @@ func (h *UserHandler) Register(c *gin.Context) { // 检测最大注册人数 var totalUser int64 h.DB.Model(&model.User{}).Count(&totalUser) - if int(totalUser) >= h.licenseService.GetLicense().UserNum { + if h.licenseService.GetLicense().UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().UserNum { resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") return } diff --git a/api/main.go b/api/main.go index a19ed49e..55a7075a 100644 --- a/api/main.go +++ b/api/main.go @@ -190,7 +190,8 @@ func main() { // MidJourney service pool fx.Provide(mj.NewServicePool), - fx.Invoke(func(pool *mj.ServicePool) { + fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { + pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) if pool.HasAvailableService() { pool.DownloadImages() pool.CheckTaskNotify() @@ -200,7 +201,8 @@ func main() { // Stable Diffusion 机器人 fx.Provide(sd.NewServicePool), - fx.Invoke(func(pool *sd.ServicePool) { + fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { + pool.InitServices(config.SdConfigs) if pool.HasAvailableService() { pool.CheckTaskNotify() pool.CheckTaskStatus() @@ -303,6 +305,7 @@ func main() { group.POST("active", h.Active) group.GET("config/get/license", h.GetLicense) group.GET("config/get/draw", h.GetDrawingConfig) + group.POST("config/update/draw", h.SaveDrawingConfig) }), fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { group := s.Engine.Group("/api/admin/") diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 0b0c93c5..dc8bdee3 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -31,48 +31,15 @@ type ServicePool struct { db *gorm.DB uploaderManager *oss.UploaderManager Clients *types.LMap[uint, *types.WsClient] // UserId => Client + licenseService *service.LicenseService } var logger = logger2.GetLogger() -func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, licenseService *service.LicenseService) *ServicePool { +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool { services := make([]*Service, 0) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) - - for k, config := range appConfig.MjPlusConfigs { - if config.Enabled == false { - continue - } - err := licenseService.IsValidApiURL(config.ApiURL) - if err != nil { - logger.Error(err) - continue - } - - cli := NewPlusClient(config) - name := fmt.Sprintf("mj-plus-service-%d", k) - plusService := NewService(name, taskQueue, notifyQueue, db, cli) - go func() { - plusService.Run() - }() - services = append(services, plusService) - } - - // for mid-journey proxy - for k, config := range appConfig.MjProxyConfigs { - if config.Enabled == false { - continue - } - cli := NewProxyClient(config) - name := fmt.Sprintf("mj-proxy-service-%d", k) - proxyService := NewService(name, taskQueue, notifyQueue, db, cli) - go func() { - proxyService.Run() - }() - services = append(services, proxyService) - } - return &ServicePool{ taskQueue: taskQueue, notifyQueue: notifyQueue, @@ -80,6 +47,47 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa uploaderManager: manager, db: db, Clients: types.NewLMap[uint, *types.WsClient](), + licenseService: licenseService, + } +} + +func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) { + // stop old service + for _, s := range p.services { + s.Stop() + } + + for k, config := range plusConfigs { + if config.Enabled == false { + continue + } + err := p.licenseService.IsValidApiURL(config.ApiURL) + if err != nil { + logger.Errorf("创建 MJ-PLUS 服务失败:%v", err) + continue + } + + cli := NewPlusClient(config) + name := fmt.Sprintf("mj-plus-service-%d", k) + plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) + go func() { + plusService.Run() + }() + p.services = append(p.services, plusService) + } + + // for mid-journey proxy + for k, config := range proxyConfigs { + if config.Enabled == false { + continue + } + cli := NewProxyClient(config) + name := fmt.Sprintf("mj-proxy-service-%d", k) + proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) + go func() { + proxyService.Run() + }() + p.services = append(p.services, proxyService) } } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 9a87b5b8..56cf3a40 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -28,6 +28,7 @@ type Service struct { taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB + running bool } func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { @@ -37,12 +38,13 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red taskQueue: taskQueue, notifyQueue: notifyQueue, Client: cli, + running: true, } } func (s *Service) Run() { logger.Infof("Starting MidJourney job consumer for %s", s.Name) - for { + for s.running { var task types.MjTask err := s.taskQueue.LPop(&task) if err != nil { @@ -125,6 +127,11 @@ func (s *Service) Run() { } } +func (s *Service) Stop() { + s.running = false + s.Client = nil +} + type CBReq struct { Id string `json:"id"` Action string `json:"action"` diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index 22110f40..776da9a8 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -25,28 +25,14 @@ type ServicePool struct { notifyQueue *store.RedisQueue db *gorm.DB Clients *types.LMap[uint, *types.WsClient] // UserId => Client + uploader *oss.UploaderManager + levelDB *store.LevelDB } -func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool { +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool { services := make([]*Service, 0) taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) - // create mj client and service - for _, config := range appConfig.SdConfigs { - if config.Enabled == false { - continue - } - - // create sd service - name := fmt.Sprintf("StableDifffusion Service-%s", config.Model) - service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB) - // run sd service - go func() { - service.Run() - }() - - services = append(services, service) - } return &ServicePool{ taskQueue: taskQueue, @@ -54,6 +40,31 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa services: services, db: db, Clients: types.NewLMap[uint, *types.WsClient](), + uploader: manager, + levelDB: levelDB, + } +} + +func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) { + // stop old service + for _, s := range p.services { + s.Stop() + } + + for k, config := range configs { + if config.Enabled == false { + continue + } + + // create sd service + name := fmt.Sprintf(" sd-service-%d", k) + service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB) + // run sd service + go func() { + service.Run() + }() + + p.services = append(p.services, service) } } diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 4f6fb2c7..2b0f27ff 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -33,6 +33,7 @@ type Service struct { uploadManager *oss.UploaderManager name string // service name leveldb *store.LevelDB + running bool // 运行状态 } func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { @@ -46,18 +47,20 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor db: db, leveldb: levelDB, uploadManager: manager, + running: true, } } func (s *Service) Run() { - for { + logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) + for s.running { var task types.SdTask err := s.taskQueue.LPop(&task) if err != nil { logger.Errorf("taking task with error: %v", err) continue } - + logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) // translate prompt if utils.HasChinese(task.Params.Prompt) { content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt)) @@ -94,6 +97,10 @@ func (s *Service) Run() { } } +func (s *Service) Stop() { + s.running = false +} + // Txt2ImgReq 文生图请求实体 type Txt2ImgReq struct { Prompt string `json:"prompt"` @@ -104,6 +111,7 @@ type Txt2ImgReq struct { Width int `json:"width"` Height int `json:"height"` SamplerName string `json:"sampler_name"` + Scheduler string `json:"scheduler"` EnableHr bool `json:"enable_hr,omitempty"` HrScale int `json:"hr_scale,omitempty"` HrUpscaler string `json:"hr_upscaler,omitempty"` @@ -137,6 +145,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { Width: task.Params.Width, Height: task.Params.Height, SamplerName: task.Params.Sampler, + Scheduler: task.Params.Scheduler, ForceTaskId: task.Params.TaskId, } if task.Params.Seed > 0 { diff --git a/web/public/index.html b/web/public/index.html index 4a3b2b4e..4f76c8bb 100644 --- a/web/public/index.html +++ b/web/public/index.html @@ -6,7 +6,7 @@ - ChatGPT-Plus + Geek-AI 创作助手 diff --git a/web/src/router.js b/web/src/router.js index 9c8739d9..3f31698c 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -105,7 +105,7 @@ const routes = [ { path: '/admin/login', name: 'admin-login', - meta: {title: 'ChatPuls 控制台登录'}, + meta: {title: 'Geek-AI 控制台登录'}, component: () => import('@/views/admin/Login.vue'), }, { @@ -113,7 +113,7 @@ const routes = [ path: '/admin', redirect: '/admin/dashboard', component: () => import("@/views/admin/Home.vue"), - meta: {title: 'ChatPuls 管理后台'}, + meta: {title: 'Geek-AI 控制台'}, children: [ { path: '/admin/dashboard', diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 2f731ba0..991fbe2c 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -29,6 +29,28 @@ +
+ + + +
+