diff --git a/api/core/app_server.go b/api/core/app_server.go index f0b5d07a..9bcd9689 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -169,9 +169,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { var tokenString string if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API tokenString = c.GetHeader(types.AdminAuthHeader) - } else if c.Request.URL.Path == "/api/chat/new" || - c.Request.URL.Path == "/api/mj/client" || - c.Request.URL.Path == "/api/sd/client" { + } else if c.Request.URL.Path == "/api/chat/new" { tokenString = c.Query("token") } else { tokenString = c.GetHeader(types.UserAuthHeader) diff --git a/api/core/config.go b/api/core/config.go index 632a7b68..5e02c821 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -33,7 +33,6 @@ func NewDefaultConfig() *types.AppConfig { BasePath: "./static/upload", }, }, - SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"}, WeChatBot: false, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, } diff --git a/api/core/types/config.go b/api/core/types/config.go index 44ff0569..ebb69708 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -16,11 +16,11 @@ type AppConfig struct { Redis RedisConfig // redis 连接信息 ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs AesEncryptKey string - SmsConfig AliYunSmsConfig // AliYun send message service config - OSS OSSConfig // OSS config - MjConfigs []MidJourneyConfig // mj 绘画配置池子 - WeChatBot bool // 是否启用微信机器人 - SdConfig StableDiffusionConfig // sd 绘画配置 + SmsConfig AliYunSmsConfig // AliYun send message service config + OSS OSSConfig // OSS config + MjConfigs []MidJourneyConfig // mj AI draw service pool + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool XXLConfig XXLConfig AlipayConfig AlipayConfig diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 83768252..d81abfe4 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -8,47 +8,30 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "encoding/base64" "fmt" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" - "github.com/gorilla/websocket" "gorm.io/gorm" - "net/http" "time" ) type SdJobHandler struct { BaseHandler - redis *redis.Client - db *gorm.DB - service *sd.Service + redis *redis.Client + db *gorm.DB + pool *sd.ServicePool } -func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler { +func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool) *SdJobHandler { h := SdJobHandler{ - redis: redisCli, - db: db, - service: service, + db: db, + pool: pool, } h.App = app return &h } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *SdJobHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - return - } - - sessionId := c.Query("session_id") - client := types.NewWsClient(ws) - // 删除旧的连接 - h.service.Clients.Put(sessionId, client) - logger.Infof("New websocket connected, IP: %s", c.ClientIP()) -} - func (h *SdJobHandler) checkLimits(c *gin.Context) bool { user, err := utils.GetLoginUser(c, h.db) if err != nil { @@ -56,6 +39,11 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { return false } + if !h.pool.HasAvailableService() { + resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!") + return false + } + if user.ImgCalls <= 0 { resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") return false @@ -67,11 +55,6 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { // Image 创建一个绘画任务 func (h *SdJobHandler) Image(c *gin.Context) { - if !h.App.Config.SdConfig.Enabled { - resp.ERROR(c, "Stable Diffusion service is disabled") - return - } - if !h.checkLimits(c) { return } @@ -129,7 +112,6 @@ func (h *SdJobHandler) Image(c *gin.Context) { Params: utils.JsonEncode(params), Prompt: data.Prompt, Progress: 0, - Started: false, CreatedAt: time.Now(), } res := h.db.Create(&job) @@ -138,7 +120,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.SdTask{ + h.pool.PushTask(types.SdTask{ Id: int(job.Id), SessionId: data.SessionId, Type: types.TaskImage, @@ -146,15 +128,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { Params: params, UserId: userId, }) - var jobVo vo.SdJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - // 推送任务到前端 - client := h.service.Clients.Get(data.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } + resp.SUCCESS(c) } @@ -193,12 +167,22 @@ func (h *SdJobHandler) JobList(c *gin.Context) { if err != nil { continue } + + if job.Progress == -1 { + h.db.Delete(&model.MidJourneyJob{Id: job.Id}) + } + if item.Progress < 100 { - // 30 分钟还没完成的任务直接删除 - if time.Now().Sub(item.CreatedAt) > time.Minute*30 { + // 10 分钟还没完成的任务直接删除 + if time.Now().Sub(item.CreatedAt) > time.Minute*10 { h.db.Delete(&item) continue } + // 正在运行中任务使用代理访问图片 + image, err := utils.DownloadImage(item.ImgURL, "") + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } } jobs = append(jobs, job) } diff --git a/api/main.go b/api/main.go index 46626b13..9f00da04 100644 --- a/api/main.go +++ b/api/main.go @@ -167,14 +167,7 @@ func main() { fx.Provide(mj.NewServicePool), // Stable Diffusion 机器人 - fx.Provide(sd.NewService), - fx.Invoke(func(config *types.AppConfig, service *sd.Service) { - if config.SdConfig.Enabled { - go func() { - service.Run() - }() - } - }), + fx.Provide(sd.NewServicePool), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), diff --git a/api/res/text2img.json b/api/res/text2img.json index 050daf75..c645454c 100644 --- a/api/res/text2img.json +++ b/api/res/text2img.json @@ -1,21 +1,21 @@ { "data": [ - "task(s95jqt5jr8yppcp)", - "A beautiful Chinese girl in a garden", + "task(owy5niy1sbbnlq0)", + "A beautiful Chinese girl plays the guitar on the beach. She is dressed in a flowing dress that matches the colors of the sunset. With her eyes closed, she strums the guitar with passion and confidence, her fingers dancing gracefully on the strings. The painting employs a vibrant color palette, capturing the warmth of the setting sun blending with the serene hues of the ocean. The artist uses a combination of impressionistic and realistic brushstrokes to convey both the girl's delicate features and the dynamic movement of the waves. The rendering effect creates a dream-like atmosphere, as if the viewer is being transported to a magical realm where music and nature intertwine. The picture is bathed in a soft, golden light, casting a warm glow on the girl's face, illuminating her joy and connection to the music she creates.", "", [], 30, - "Euler a", + "DPM++ 3M SDE Karras", 1, 1, 7, 512, 512, - true, + false, 0.7, 2, "Latent", - 10, + 0, 0, 0, "Use same checkpoint", @@ -33,6 +33,9 @@ 0, 0, 0, + null, + null, + null, false, false, "positive", @@ -55,13 +58,22 @@ false, false, 0, - [ - ], + null, + null, + false, + null, + null, + false, + null, + null, + false, + 50, + [], "", "", "" ], "event_data": null, - "fn_index": 95, - "session_hash": "eqwumnt3rov" + "fn_index": 316, + "session_hash": "ttr8efgt63g" } \ No newline at end of file diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index ad4e0c9d..efcd853e 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -60,7 +60,7 @@ func (p *ServicePool) PushTask(task types.MjTask) { p.taskQueue.RPush(task) } -// HasAvailableService check if has available mj service in pool +// HasAvailableService check if it has available mj service in pool func (p *ServicePool) HasAvailableService() bool { return len(p.services) > 0 } diff --git a/api/service/mj/service.go b/api/service/mj/service.go index e991ef01..754966e5 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -2,7 +2,6 @@ package mj import ( "chatplus/core/types" - "chatplus/service" "chatplus/service/oss" "chatplus/store" "chatplus/store/model" @@ -24,7 +23,6 @@ type Service struct { handledTaskNum int32 // already handled task number taskStartTimes map[int]time.Time // task start time, to check if the task is timeout taskTimeout int64 - snowflake *service.Snowflake } func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { @@ -127,6 +125,12 @@ func (s *Service) Notify(data CBReq) { job.Hash = data.Image.Hash job.OrgURL = data.Image.URL + res = s.db.Updates(&job) + if res.Error != nil { + logger.Error("error with update job: ", res.Error) + return + } + // upload image if data.Status == Finished { imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) @@ -135,12 +139,7 @@ func (s *Service) Notify(data CBReq) { return } job.ImgURL = imgURL - } - - res = s.db.Updates(&job) - if res.Error != nil { - logger.Error("error with update job: ", res.Error) - return + s.db.Updates(&job) } if data.Status == Finished { diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go new file mode 100644 index 00000000..4a817983 --- /dev/null +++ b/api/service/sd/pool.go @@ -0,0 +1,52 @@ +package sd + +import ( + "chatplus/core/types" + "chatplus/service/oss" + "chatplus/store" + "fmt" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type ServicePool struct { + services []*Service + taskQueue *store.RedisQueue +} + +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { + services := make([]*Service, 0) + queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) + // create mj client and service + for k, config := range appConfig.SdConfigs { + if config.Enabled == false { + continue + } + + // create sd service + name := fmt.Sprintf("StableDifffusion Service-%d", k) + service := NewService(name, 4, 600, &config, queue, db, manager) + // run sd service + go func() { + service.Run() + }() + + services = append(services, service) + } + + return &ServicePool{ + taskQueue: queue, + services: services, + } +} + +// PushTask push a new mj task in to task queue +func (p *ServicePool) PushTask(task types.SdTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + p.taskQueue.RPush(task) +} + +// HasAvailableService check if it has available mj service in pool +func (p *ServicePool) HasAvailableService() bool { + return len(p.services) > 0 +} diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 3d32cf7a..0a1881e2 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -5,84 +5,96 @@ import ( "chatplus/service/oss" "chatplus/store" "chatplus/store/model" - "chatplus/store/vo" "chatplus/utils" - "context" "encoding/json" "fmt" - "github.com/go-redis/redis/v8" "github.com/imroc/req/v3" "gorm.io/gorm" "io" "os" "strconv" + "sync/atomic" "time" ) // SD 绘画服务 -const RunningJobKey = "StableDiffusion_Running_Job" - type Service struct { - httpClient *req.Client - config *types.StableDiffusionConfig - taskQueue *store.RedisQueue - redis *redis.Client - db *gorm.DB - uploadManager *oss.UploaderManager - Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池 + httpClient *req.Client + config *types.StableDiffusionConfig + taskQueue *store.RedisQueue + db *gorm.DB + uploadManager *oss.UploaderManager + name string // service name + maxHandleTaskNum int32 // max task number current service can handle + handledTaskNum int32 // already handled task number + taskStartTimes map[int]time.Time // task start time, to check if the task is timeout + taskTimeout int64 } -func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service { +func NewService(name string, maxTaskNum int32, timeout int64, config *types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { return &Service{ - config: &config.SdConfig, - httpClient: req.C(), - redis: redisCli, - db: db, - uploadManager: manager, - Clients: types.NewLMap[string, *types.WsClient](), - taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli), + name: name, + config: config, + httpClient: req.C(), + taskQueue: queue, + db: db, + uploadManager: manager, + taskTimeout: timeout, + maxHandleTaskNum: maxTaskNum, + taskStartTimes: make(map[int]time.Time), } } func (s *Service) Run() { - logger.Info("Starting StableDiffusion job consumer.") - ctx := context.Background() for { - _, err := s.redis.Get(ctx, RunningJobKey).Result() - if err == nil { // 队列串行执行 + s.checkTasks() + if !s.canHandleTask() { + // current service is full, can not handle more task + // waiting for running task finish time.Sleep(time.Second * 3) continue } + var task types.SdTask - err = s.taskQueue.LPop(&task) + err := s.taskQueue.LPop(&task) if err != nil { logger.Errorf("taking task with error: %v", err) continue } - logger.Infof("Consuming Task: %+v", task) + logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) err = s.Txt2Img(task) if err != nil { logger.Error("绘画任务执行失败:", err) - if task.RetryCount <= 5 { - s.taskQueue.RPush(task) - } - task.RetryCount += 1 - time.Sleep(time.Second * 3) + // update the task progress + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + // release task num + atomic.AddInt32(&s.handledTaskNum, -1) continue } - // 更新任务的执行状态 - s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) - // 锁定任务执行通道,直到任务超时(5分钟) - s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5) + // lock the task until the execute timeout + s.taskStartTimes[task.Id] = time.Now() + atomic.AddInt32(&s.handledTaskNum, 1) } } -// PushTask 推送任务到队列 -func (s *Service) PushTask(task types.SdTask) { - logger.Infof("add a new Stable Diffusion Task: %+v", task) - s.taskQueue.RPush(task) +// check if current service instance can handle more task +func (s *Service) canHandleTask() bool { + handledNum := atomic.LoadInt32(&s.handledTaskNum) + return handledNum < s.maxHandleTaskNum +} + +// remove the expired tasks +func (s *Service) checkTasks() { + for k, t := range s.taskStartTimes { + if time.Now().Unix()-t.Unix() > s.taskTimeout { + delete(s.taskStartTimes, k) + atomic.AddInt32(&s.handledTaskNum, -1) + // delete task from database + s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") + } + } } // Txt2Img 文生图 API @@ -237,9 +249,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { } func (s *Service) callback(data CBReq) { - // 释放任务锁 - s.redis.Del(context.Background(), RunningJobKey) - client := s.Clients.Get(data.SessionId) + // release task num + atomic.AddInt32(&s.handledTaskNum, -1) if data.Success { // 任务成功 var job model.SdJob res := s.db.Where("id = ?", data.JobId).First(&job) @@ -259,13 +270,15 @@ func (s *Service) callback(data CBReq) { params.Seed = data.Seed if data.ImageName != "" { // 下载图片 - imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) - imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false) - if err != nil { - logger.Error("error with download img: ", err.Error()) - return + job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) + if data.Progress == 100 { + imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false) + if err != nil { + logger.Error("error with download img: ", err.Error()) + return + } + job.ImgURL = imageURL } - job.ImgURL = imageURL } job.Params = utils.JsonEncode(params) @@ -275,38 +288,16 @@ func (s *Service) callback(data CBReq) { return } - var jobVo vo.SdJob - err = utils.CopyObject(job, &jobVo) - if err != nil { - logger.Error("error with copy object: ", err) - return - } - - if data.Progress < 100 && data.ImageData != "" { - jobVo.ImgURL = data.ImageData - } - - logger.Infof("绘图进度:%d", data.Progress) + logger.Debugf("绘图进度:%d", data.Progress) // 扣减绘图次数 if data.Progress == 100 { - s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", jobVo.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) - } - // 推送任务到前端 - if client != nil { - utils.ReplyChunkMessage(client, jobVo) + s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } + } else { // 任务失败 logger.Error("任务执行失败:", data.Message) - // 删除任务 - s.db.Delete(&model.SdJob{Id: uint(data.JobId)}) - // 推送消息到前端 - if client != nil { - utils.ReplyChunkMessage(client, vo.SdJob{ - Id: uint(data.JobId), - Progress: -1, - TaskId: data.TaskId, - }) - } + // update the task progress + s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1) } } diff --git a/api/store/model/sd_job.go b/api/store/model/sd_job.go index 473321b9..65f3e6fe 100644 --- a/api/store/model/sd_job.go +++ b/api/store/model/sd_job.go @@ -11,7 +11,6 @@ type SdJob struct { Progress int Prompt string Params string - Started bool CreatedAt time.Time } diff --git a/api/store/vo/sd_job.go b/api/store/vo/sd_job.go index c4ae2308..26b3be16 100644 --- a/api/store/vo/sd_job.go +++ b/api/store/vo/sd_job.go @@ -15,5 +15,4 @@ type SdJob struct { Progress int `json:"progress"` Prompt string `json:"prompt"` CreatedAt time.Time `json:"created_at"` - Started bool `json:"started"` } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 0bdd9afa..bc9b8bc4 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -266,7 +266,6 @@ 翻译并重写 - @@ -580,7 +579,7 @@ const fetchRunningJobs = (userId) => { } runningJobs.value = _jobs - setTimeout(() => fetchRunningJobs(userId), 10000) + setTimeout(() => fetchRunningJobs(userId), 5000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) @@ -591,7 +590,7 @@ const fetchFinishJobs = (userId) => { // 获取已完成的任务 httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => { finishedJobs.value = res.data - setTimeout(() => fetchFinishJobs(userId), 10000) + setTimeout(() => fetchFinishJobs(userId), 5000) }).catch(e => { ElMessage.error("获取任务失败:" + e.message) }) diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 5c9945eb..c6d6723b 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -241,7 +241,7 @@ -
+
+
+ + + + + 翻译 + + + + + + + + 翻译并重写 + + +
+
反向提示词:
-
- - - +
+ 绘图可用额度:{{ imgCalls }}
@@ -478,21 +498,21 @@