From 5034a203454d3b0a183a2ec98ab104f8aada75d7 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sun, 17 Sep 2023 18:03:45 +0800 Subject: [PATCH] feat: mj advance drawing page function is ready, use better task scheduling argorithm --- api/core/types/function.go | 26 +---- api/handler/mj_handler.go | 157 +++++++++++++++++++++++++------ api/main.go | 2 +- api/service/function/func_mj.go | 24 ----- api/service/mj_service.go | 14 ++- database/update-v3.1.3.sql | 4 +- web/src/assets/css/image-mj.styl | 15 ++- web/src/views/ImageMj.vue | 101 +++++++++++++++----- 8 files changed, 234 insertions(+), 109 deletions(-) diff --git a/api/core/types/function.go b/api/core/types/function.go index dd0a6e4e..7b0a74db 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -83,31 +83,7 @@ var InnerFunctions = []Function{ Properties: map[string]Property{ "prompt": { Type: "string", - Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文", - }, - "--ar": { - Type: "string", - Description: "图片长宽比,默认值 16:9", - }, - "--niji": { - Type: "string", - Description: "动漫模型版本,默认值空", - }, - "--s": { - Type: "string", - Description: "风格,stylize", - }, - "--seed": { - Type: "string", - Description: "随机种子", - }, - "--no": { - Type: "string", - Description: "负面提示词,指定不要什么元素或者风格,如果该参数中有中文的话,则需要翻译成英文", - }, - "--v": { - Type: "string", - Description: "模型版本,默认值: 5.2", + Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。提示词中的参数作为提示的一部分,不要删除", }, }, Required: []string{}, diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 13ed5daf..ff906260 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "gorm.io/gorm" - "net/http" + "strings" "sync" "time" ) @@ -24,7 +24,6 @@ type TaskStatus string const ( Stopped = TaskStatus("Stopped") Finished = TaskStatus("Finished") - Running = TaskStatus("Running") ) type Image struct { @@ -117,18 +116,25 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro return nil, false } + var job model.MidJourneyJob + res := h.db.Where("message_id = ?", data.MessageId).First(&job) + if res.Error == nil && data.Status == Finished { + logger.Warn("重复消息:", data.MessageId) + return nil, false + } + if task.Src == service.TaskSrcImg { // 绘画任务 - logger.Error(err) var job model.MidJourneyJob - res := h.db.First(&job, task.Id) + res := h.db.Where("id = ?", task.Id).First(&job) if res.Error != nil { - logger.Warn("非法任务:", err) + logger.Warn("非法任务:", res.Error) return nil, false } job.MessageId = data.MessageId job.ReferenceId = data.ReferenceId job.Progress = data.Progress job.Prompt = data.Prompt + job.Hash = data.Image.Hash // 任务完成,将最终的图片下载下来 if data.Progress == 100 { @@ -144,18 +150,11 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro } res = h.db.Updates(&job) if res.Error != nil { - logger.Error("error with update job: ", err.Error()) + logger.Error("error with update job: ", res.Error) return res.Error, false } } else if task.Src == service.TaskSrcChat { // 聊天任务 - var job model.MidJourneyJob - res := h.db.Where("message_id = ?", data.MessageId).First(&job) - if res.Error == nil { - logger.Warn("重复消息:", data.MessageId) - return nil, false - } - wsClient := h.App.MjTaskClients.Get(task.Id) if data.Status == Finished { if wsClient != nil && data.ReferenceId != "" { @@ -233,15 +232,68 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro return nil, true } -// Proxy 通过代理访问 discord f服务器图片 -func (h *MidJourneyHandler) Proxy(c *gin.Context) { - imgURL := c.Query("url") - imageData, err := utils.DownloadImage(imgURL, h.App.Config.ProxyURL) - if err != nil { - c.String(http.StatusOK, err.Error()) +// Image 创建一个绘画任务 +func (h *MidJourneyHandler) Image(c *gin.Context) { + var data struct { + Prompt string `json:"prompt"` + Rate string `json:"rate"` + Model string `json:"model"` + Chaos int `json:"chaos"` + Raw bool `json:"raw"` + Seed int64 `json:"seed"` + Stylize int `json:"stylize"` + Img string `json:"img"` + Weight float32 `json:"weight"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) return } - c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + var prompt = data.Prompt + if data.Rate != "" && !strings.Contains(prompt, "--ar") { + prompt += " --ar " + data.Rate + } + if data.Seed > 0 && !strings.Contains(prompt, "--seed") { + prompt += fmt.Sprintf(" --seed %d", data.Seed) + } + if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { + prompt += fmt.Sprintf(" --s %d", data.Stylize) + } + if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { + prompt += fmt.Sprintf(" --c %d", data.Chaos) + } + if data.Img != "" { + prompt = fmt.Sprintf("%s %s", data.Img, prompt) + if data.Weight > 0 { + prompt += fmt.Sprintf(" --iw %f", data.Weight) + } + } + if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { + prompt += data.Model + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.MidJourneyJob{ + Type: service.Image.String(), + UserId: userId, + Progress: 0, + Prompt: prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error != nil { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + + h.mjService.PushTask(service.MjTask{ + Id: fmt.Sprintf("%d", job.Id), + Src: service.TaskSrcImg, + Type: service.Image, + Prompt: prompt, + UserId: userId, + }) + resp.SUCCESS(c) } type reqVo struct { @@ -264,13 +316,32 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - userId, _ := c.Get(types.LoginUserID) + idValue, _ := c.Get(types.LoginUserID) + jobId := data.SessionId + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + src := service.TaskSrc(data.Src) + if src == service.TaskSrcImg { + job := model.MidJourneyJob{ + Type: service.Upscale.String(), + UserId: userId, + Hash: data.MessageHash, + Progress: 0, + Prompt: data.Prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error == nil { + jobId = fmt.Sprintf("%d", job.Id) + } else { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + } h.mjService.PushTask(service.MjTask{ - Id: data.SessionId, - Src: service.TaskSrc(data.Src), + Id: jobId, + Src: src, Type: service.Upscale, Prompt: data.Prompt, - UserId: utils.IntValue(utils.InterfaceToString(userId), 0), + UserId: userId, RoleId: data.RoleId, Icon: data.Icon, ChatId: data.ChatId, @@ -298,13 +369,33 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - userId, _ := c.Get(types.LoginUserID) + idValue, _ := c.Get(types.LoginUserID) + jobId := data.SessionId + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + src := service.TaskSrc(data.Src) + if src == service.TaskSrcImg { + job := model.MidJourneyJob{ + Type: service.Variation.String(), + UserId: userId, + ImgURL: "", + Hash: data.MessageHash, + Progress: 0, + Prompt: data.Prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error == nil { + jobId = fmt.Sprintf("%d", job.Id) + } else { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + } h.mjService.PushTask(service.MjTask{ - Id: data.SessionId, - Src: service.TaskSrc(data.Src), + Id: jobId, + Src: src, Type: service.Variation, Prompt: data.Prompt, - UserId: utils.IntValue(utils.InterfaceToString(userId), 0), + UserId: userId, RoleId: data.RoleId, Icon: data.Icon, ChatId: data.ChatId, @@ -332,9 +423,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { var res *gorm.DB userId, _ := c.Get(types.LoginUserID) if status == 1 { - res = h.db.Where("user_id = ? AND progress = 100", userId).Find(&items) + res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) } else { - res = h.db.Where("user_id = ? AND progress < 100", userId).Find(&items) + res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) } if res.Error != nil { resp.ERROR(c, types.NoData) @@ -348,6 +439,12 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { if err != nil { continue } + if item.Progress < 100 && item.ImgURL != "" { // 正在运行中任务使用代理访问图片 + image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } jobs = append(jobs, job) } resp.SUCCESS(c, jobs) diff --git a/api/main.go b/api/main.go index 5bbdb951..d9a2559e 100644 --- a/api/main.go +++ b/api/main.go @@ -191,10 +191,10 @@ func main() { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { group := s.Engine.Group("/api/mj/") group.POST("notify", h.Notify) + group.POST("image", h.Image) group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) group.GET("jobs", h.JobList) - group.GET("proxy", h.Proxy) }), // 管理后台控制器 diff --git a/api/service/function/func_mj.go b/api/service/function/func_mj.go index 035645cc..19c84407 100644 --- a/api/service/function/func_mj.go +++ b/api/service/function/func_mj.go @@ -3,7 +3,6 @@ package function import ( "chatplus/service" "chatplus/utils" - "fmt" ) // AI 绘画函数 @@ -22,29 +21,6 @@ func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney { func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) - if !utils.IsEmptyValue(params["ar"]) { - prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) - delete(params, "ar") - } - if !utils.IsEmptyValue(params["s"]) { - prompt = fmt.Sprintf("%s --s %s", prompt, params["s"]) - delete(params, "s") - } - if !utils.IsEmptyValue(params["seed"]) { - prompt = fmt.Sprintf("%s --seed %s", prompt, params["seed"]) - delete(params, "seed") - } - if !utils.IsEmptyValue(params["no"]) { - prompt = fmt.Sprintf("%s --no %s", prompt, params["no"]) - delete(params, "no") - } - if !utils.IsEmptyValue(params["niji"]) { - prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"]) - delete(params, "niji") - } else { - prompt = prompt + " --v 5.2" - } - f.service.PushTask(service.MjTask{ Id: utils.InterfaceToString(params["session_id"]), Src: service.TaskSrcChat, diff --git a/api/service/mj_service.go b/api/service/mj_service.go index 1f583771..ae23d33e 100644 --- a/api/service/mj_service.go +++ b/api/service/mj_service.go @@ -4,12 +4,14 @@ import ( "chatplus/core/types" logger2 "chatplus/logger" "chatplus/store" + "chatplus/store/model" "chatplus/utils" "context" "errors" "fmt" "github.com/go-redis/redis/v8" "github.com/imroc/req/v3" + "gorm.io/gorm" "time" ) @@ -58,12 +60,14 @@ type MjService struct { client *req.Client taskQueue *store.RedisQueue redis *redis.Client + db *gorm.DB } -func NewMjService(appConfig *types.AppConfig, client *redis.Client) *MjService { +func NewMjService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *MjService { return &MjService{ config: appConfig.ExtConfig, redis: client, + db: db, taskQueue: store.NewRedisQueue("midjourney_task_queue", client), client: req.C().SetTimeout(30 * time.Second)} } @@ -104,9 +108,11 @@ func (s *MjService) Run() { } if err != nil { logger.Error("绘画任务执行失败:", err) - //if task.RetryCount > 5 { - // continue - //} + if task.RetryCount > 5 { + // 取消并删除任务 + s.db.Where("id = ?", task.Id).Delete(&model.MidJourneyJob{}) + continue + } task.RetryCount += 1 s.taskQueue.RPush(task) // TODO: 执行失败通知聊天客户端 diff --git a/database/update-v3.1.3.sql b/database/update-v3.1.3.sql index e86a2934..28d0b729 100644 --- a/database/update-v3.1.3.sql +++ b/database/update-v3.1.3.sql @@ -7,4 +7,6 @@ ALTER TABLE `chatgpt_mj_jobs` ADD `hash` VARCHAR(100) NULL DEFAULT NULL COMMENT ALTER TABLE `chatgpt_mj_jobs` ADD `img_url` VARCHAR(255) NULL DEFAULT NULL COMMENT '图片URL' AFTER `prompt`; -- 2023-09-15 -ALTER TABLE `chatgpt_mj_jobs` ADD `type` VARCHAR(20) NULL DEFAULT 'image' COMMENT '任务类别' AFTER `user_id`; \ No newline at end of file +ALTER TABLE `chatgpt_mj_jobs` ADD `type` VARCHAR(20) NULL DEFAULT 'image' COMMENT '任务类别' AFTER `user_id`; +ALTER TABLE `chatgpt_mj_jobs` DROP INDEX `message_id`; +ALTER TABLE `chatgpt_mj_jobs` ADD INDEX(`message_id`); \ No newline at end of file diff --git a/web/src/assets/css/image-mj.styl b/web/src/assets/css/image-mj.styl index a8fdc40c..85362067 100644 --- a/web/src/assets/css/image-mj.styl +++ b/web/src/assets/css/image-mj.styl @@ -234,7 +234,6 @@ display flex justify-content center align-items center - background-color: rgba(0, 0, 0, 0.5) span { font-size 20px @@ -249,6 +248,8 @@ .finish-job-list { .job-item { margin-bottom 20px + width 100% + height 100% .opt { .opt-line { @@ -289,12 +290,24 @@ height 100% max-height 240px + img { + height 240px + } + + .el-image-viewer__wrapper { + img { + width auto + height auto + } + } + .image-slot { display flex flex-flow column justify-content center align-items center height 100% + min-height 200px color #ffffff .iconfont { diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index cd45e707..48635187 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -60,7 +60,7 @@