From cc7271aa732d19bd605a6067635885a5759f2896 Mon Sep 17 00:00:00 2001 From: RockYang Date: Thu, 28 Sep 2023 18:09:45 +0800 Subject: [PATCH] feat: stable diffusion page is ready --- api/core/app_server.go | 4 +- api/core/types/config.go | 18 +- api/core/types/task.go | 26 +- api/handler/mj_handler.go | 31 +- api/handler/sd_handler.go | 514 +++++++++++--------------- api/handler/sms_handler.go | 2 +- api/handler/user_handler.go | 4 +- api/main.go | 15 + api/service/mj/service.go | 6 +- api/service/oss/aliyun_oss.go | 10 +- api/service/oss/localstorage.go | 8 +- api/service/oss/minio_oss.go | 10 +- api/service/oss/qiniu_oss.go | 10 +- api/service/oss/uploader.go | 2 +- api/service/sd/client.go | 169 --------- api/service/sd/sd_service.go | 72 ---- api/service/sd/service.go | 300 ++++++++++++++++ api/service/sd/types.go | 278 ++++++++------- web/src/assets/css/image-sd.css | 187 ++++++++++ web/src/assets/css/image-sd.styl | 255 +++++++++++++ web/src/views/ImageMj.vue | 9 +- web/src/views/ImageSd.vue | 575 ++++++++++++++++++++++++++++-- web/src/views/admin/SysConfig.vue | 4 +- 23 files changed, 1730 insertions(+), 779 deletions(-) delete mode 100644 api/service/sd/client.go delete mode 100644 api/service/sd/sd_service.go create mode 100644 api/service/sd/service.go create mode 100644 web/src/assets/css/image-sd.css create mode 100644 web/src/assets/css/image-sd.styl diff --git a/api/core/app_server.go b/api/core/app_server.go index aa7016bd..1bdc6635 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -157,7 +157,9 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { var tokenString string if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API tokenString = c.GetHeader(types.AdminAuthHeader) - } else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" { + } else if c.Request.URL.Path == "/api/chat/new" || + c.Request.URL.Path == "/api/mj/client" || + c.Request.URL.Path == "/api/sd/client" { tokenString = c.Query("token") } else { tokenString = c.GetHeader(types.UserAuthHeader) diff --git a/api/core/types/config.go b/api/core/types/config.go index 10aaa04a..4f0c1a19 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -101,13 +101,13 @@ type ModelAPIConfig struct { } type SystemConfig struct { - Title string `json:"title"` - AdminTitle string `json:"admin_title"` - Models []string `json:"models"` - UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用 - InitImgCalls int `json:"init_img_calls"` - VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数 - EnabledRegister bool `json:"enabled_register"` - EnabledMsgService bool `json:"enabled_msg_service"` - EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能 + Title string `json:"title"` + AdminTitle string `json:"admin_title"` + Models []string `json:"models"` + UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用 + InitImgCalls int `json:"init_img_calls"` + VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数 + EnabledRegister bool `json:"enabled_register"` + EnabledMsg bool `json:"enabled_msg"` // 启用短信验证码服务 + EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能 } diff --git a/api/core/types/task.go b/api/core/types/task.go index a91d30cb..7256903e 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -40,7 +40,7 @@ type MjTask struct { } type SdTask struct { - Id int `json:"id"` + Id int `json:"id"` // job 数据库ID SessionId string `json:"session_id"` Src TaskSrc `json:"src"` Type TaskType `json:"type"` @@ -52,18 +52,18 @@ type SdTask struct { type SdTaskParams struct { TaskId string `json:"task_id"` - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt"` - Steps int `json:"steps"` - Sampler string `json:"sampler"` - FaceFix bool `json:"face_fix"` - CfgScale float32 `json:"cfg_scale"` - Seed int64 `json:"seed"` + Prompt string `json:"prompt"` // 提示词 + NegativePrompt string `json:"negative_prompt"` // 反向提示词 + Steps int `json:"steps"` // 迭代步数,默认20 + Sampler string `json:"sampler"` // 采样器 + FaceFix bool `json:"face_fix"` // 面部修复 + CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7 + Seed int64 `json:"seed"` // 随机数种子 Height int `json:"height"` Width int `json:"width"` - HdFix bool `json:"hd_fix"` - HdRedrawRate float32 `json:"hd_redraw_rate"` - HdScale int `json:"hd_scale"` - HdScaleAlg string `json:"hd_scale_alg"` - HdSampleNum int `json:"hd_sample_num"` + HdFix bool `json:"hd_fix"` // 启用高清修复 + HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度 + HdScale int `json:"hd_scale"` // 放大倍数 + HdScaleAlg string `json:"hd_scale_alg"` // 放大算法 + HdSteps int `json:"hd_steps"` // 高清修复迭代步数 } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index a0c270b2..eff11fe6 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -4,7 +4,6 @@ import ( "chatplus/core" "chatplus/core/types" "chatplus/service/mj" - "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -17,33 +16,25 @@ import ( "gorm.io/gorm" "net/http" "strings" - "sync" "time" ) type MidJourneyHandler struct { BaseHandler - redis *redis.Client - db *gorm.DB - mjService *mj.Service - uploaderManager *oss.UploaderManager - lock sync.Mutex - clients *types.LMap[string, *types.WsClient] + redis *redis.Client + db *gorm.DB + mjService *mj.Service } func NewMidJourneyHandler( app *core.AppServer, client *redis.Client, db *gorm.DB, - manager *oss.UploaderManager, mjService *mj.Service) *MidJourneyHandler { h := MidJourneyHandler{ - redis: client, - db: db, - uploaderManager: manager, - lock: sync.Mutex{}, - mjService: mjService, - clients: types.NewLMap[string, *types.WsClient](), + redis: client, + db: db, + mjService: mjService, } h.App = app return &h @@ -59,9 +50,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { sessionId := c.Query("session_id") client := types.NewWsClient(ws) - // 删除旧的连接 - h.clients.Delete(sessionId) - h.clients.Put(sessionId, client) + h.mjService.Clients.Put(sessionId, client) logger.Infof("New websocket connected, IP: %s", c.ClientIP()) } @@ -156,7 +145,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { err := utils.CopyObject(job, &jobVo) if err == nil { // 推送任务到前端 - client := h.clients.Get(data.SessionId) + client := h.mjService.Clients.Get(data.SessionId) if client != nil { utils.ReplyChunkMessage(client, jobVo) } @@ -212,7 +201,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { err := utils.CopyObject(job, &jobVo) if err == nil { // 推送任务到前端 - client := h.clients.Get(data.SessionId) + client := h.mjService.Clients.Get(data.SessionId) if client != nil { utils.ReplyChunkMessage(client, jobVo) } @@ -283,7 +272,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { err := utils.CopyObject(job, &jobVo) if err == nil { // 推送任务到前端 - client := h.clients.Get(data.SessionId) + client := h.mjService.Clients.Get(data.SessionId) if client != nil { utils.ReplyChunkMessage(client, jobVo) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index ac79c40d..c0e01f19 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -1,316 +1,202 @@ package handler -// -//import ( -// "chatplus/core" -// "chatplus/core/types" -// "chatplus/service" -// "chatplus/service/oss" -// "chatplus/store/model" -// "chatplus/store/vo" -// "chatplus/utils" -// "chatplus/utils/resp" -// "encoding/base64" -// "fmt" -// "github.com/gin-gonic/gin" -// "github.com/go-redis/redis/v8" -// "github.com/gorilla/websocket" -// "gorm.io/gorm" -// "net/http" -// "strings" -// "sync" -// "time" -//) -// -//type SdJobHandler struct { -// BaseHandler -// redis *redis.Client -// db *gorm.DB -// mjService *service.MjService -// uploaderManager *oss.UploaderManager -// lock sync.Mutex -// clients *types.LMap[string, *types.WsClient] -//} -// -//func NewSdJobHandler( -// app *core.AppServer, -// client *redis.Client, -// db *gorm.DB, -// manager *oss.UploaderManager, -// mjService *service.MjService) *MidJourneyHandler { -// h := MidJourneyHandler{ -// redis: client, -// db: db, -// uploaderManager: manager, -// lock: sync.Mutex{}, -// mjService: mjService, -// clients: types.NewLMap[string, *types.WsClient](), -// } -// h.App = app -// return &h -//} -// -//// Client WebSocket 客户端,用于通知任务状态变更 -//func (h *SdJobHandler) Client(c *gin.Context) { -// ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) -// if err != nil { -// logger.Error(err) -// return -// } -// -// sessionId := c.Query("session_id") -// client := types.NewWsClient(ws) -// // 删除旧的连接 -// h.clients.Delete(sessionId) -// h.clients.Put(sessionId, client) -// logger.Infof("New websocket connected, IP: %s", c.ClientIP()) -//} -// -//type sdNotifyData struct { -// TaskId string -// ImageName string -// ImageData string -// Progress int -// Seed string -// Success bool -// Message string -//} -// -//func (h *SdJobHandler) Notify(c *gin.Context) { -// token := c.GetHeader("Authorization") -// if token != h.App.Config.ExtConfig.Token { -// resp.NotAuth(c) -// return -// } -// var data sdNotifyData -// if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" { -// resp.ERROR(c, types.InvalidArgs) -// return -// } -// logger.Debugf("收到 MidJourney 回调请求:%+v", data) -// -// h.lock.Lock() -// defer h.lock.Unlock() -// -// err, finished := h.notifyHandler(c, data) -// if err != nil { -// resp.ERROR(c, err.Error()) -// return -// } -// -// // 解除任务锁定 -// if finished && (data.Progress == 100) { -// h.redis.Del(c, service.MjRunningJobKey) -// } -// resp.SUCCESS(c) -// -//} -// -//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) { -// taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() -// if err != nil { // 过期任务,丢弃 -// logger.Warn("任务已过期:", err) -// return nil, true -// } -// -// var task types.SdTask -// err = utils.JsonDecode(taskString, &task) -// if err != nil { // 非标准任务,丢弃 -// logger.Warn("任务解析失败:", err) -// return nil, false -// } -// -// var job model.SdJob -// res := h.db.Where("id = ?", task.Id).First(&job) -// if res.Error != nil { -// logger.Warn("非法任务:", res.Error) -// return nil, false -// } -// job.Params = utils.JsonEncode(task.Params) -// job.ReferenceId = data.ImageData -// job.Progress = data.Progress -// job.Prompt = data.Prompt -// job.Hash = data.Image.Hash -// -// // 任务完成,将最终的图片下载下来 -// if data.Progress == 100 { -// imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) -// if err != nil { -// logger.Error("error with download img: ", err.Error()) -// return err, false -// } -// job.ImgURL = imgURL -// } else { -// // 临时图片直接保存,访问的时候使用代理进行转发 -// job.ImgURL = data.Image.URL -// } -// res = h.db.Updates(&job) -// if res.Error != nil { -// logger.Error("error with update job: ", res.Error) -// return res.Error, false -// } -// -// var jobVo vo.MidJourneyJob -// err := utils.CopyObject(job, &jobVo) -// if err == nil { -// if data.Progress < 100 { -// image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) -// if err == nil { -// jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) -// } -// } -// -// // 推送任务到前端 -// client := h.clients.Get(task.SessionId) -// if client != nil { -// utils.ReplyChunkMessage(client, jobVo) -// } -// } -// -// // 更新用户剩余绘图次数 -// if data.Progress == 100 { -// h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) -// } -// -// return nil, true -//} -// -//func (h *SdJobHandler) checkLimits(c *gin.Context) bool { -// user, err := utils.GetLoginUser(c, h.db) -// if err != nil { -// resp.NotAuth(c) -// return false -// } -// -// if user.ImgCalls <= 0 { -// resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") -// return false -// } -// -// return true -// -//} -// -//// Image 创建一个绘画任务 -//func (h *SdJobHandler) Image(c *gin.Context) { -// var data struct { -// SessionId string `json:"session_id"` -// Prompt string `json:"prompt"` -// Rate string `json:"rate"` -// Model string `json:"model"` -// Chaos int `json:"chaos"` -// Raw bool `json:"raw"` -// Seed int64 `json:"seed"` -// Stylize int `json:"stylize"` -// Img string `json:"img"` -// Weight float32 `json:"weight"` -// } -// if err := c.ShouldBindJSON(&data); err != nil { -// resp.ERROR(c, types.InvalidArgs) -// return -// } -// if !h.checkLimits(c) { -// return -// } -// -// var prompt = data.Prompt -// if data.Rate != "" && !strings.Contains(prompt, "--ar") { -// prompt += " --ar " + data.Rate -// } -// if data.Seed > 0 && !strings.Contains(prompt, "--seed") { -// prompt += fmt.Sprintf(" --seed %d", data.Seed) -// } -// if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { -// prompt += fmt.Sprintf(" --s %d", data.Stylize) -// } -// if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { -// prompt += fmt.Sprintf(" --c %d", data.Chaos) -// } -// if data.Img != "" { -// prompt = fmt.Sprintf("%s %s", data.Img, prompt) -// if data.Weight > 0 { -// prompt += fmt.Sprintf(" --iw %f", data.Weight) -// } -// } -// if data.Raw { -// prompt += " --style raw" -// } -// if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { -// prompt += data.Model -// } -// -// idValue, _ := c.Get(types.LoginUserID) -// userId := utils.IntValue(utils.InterfaceToString(idValue), 0) -// job := model.MidJourneyJob{ -// Type: service.Image.String(), -// UserId: userId, -// Progress: 0, -// Prompt: prompt, -// CreatedAt: time.Now(), -// } -// if res := h.db.Create(&job); res.Error != nil { -// resp.ERROR(c, "添加任务失败:"+res.Error.Error()) -// return -// } -// -// h.mjService.PushTask(service.MjTask{ -// Id: int(job.Id), -// SessionId: data.SessionId, -// Src: service.TaskSrcImg, -// Type: service.Image, -// Prompt: prompt, -// UserId: userId, -// }) -// -// var jobVo vo.MidJourneyJob -// err := utils.CopyObject(job, &jobVo) -// if err == nil { -// // 推送任务到前端 -// client := h.clients.Get(data.SessionId) -// if client != nil { -// utils.ReplyChunkMessage(client, jobVo) -// } -// } -// resp.SUCCESS(c) -//} -// -//// JobList 获取 MJ 任务列表 -//func (h *SdJobHandler) JobList(c *gin.Context) { -// status := h.GetInt(c, "status", 0) -// var items []model.MidJourneyJob -// var res *gorm.DB -// userId, _ := c.Get(types.LoginUserID) -// if status == 1 { -// res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) -// } else { -// res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) -// } -// if res.Error != nil { -// resp.ERROR(c, types.NoData) -// return -// } -// -// var jobs = make([]vo.MidJourneyJob, 0) -// for _, item := range items { -// var job vo.MidJourneyJob -// err := utils.CopyObject(item, &job) -// if err != nil { -// continue -// } -// if item.Progress < 100 { -// // 30 分钟还没完成的任务直接删除 -// if time.Now().Sub(item.CreatedAt) > time.Minute*30 { -// h.db.Delete(&item) -// continue -// } -// if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 -// image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) -// if err == nil { -// job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) -// } -// } -// } -// jobs = append(jobs, job) -// } -// resp.SUCCESS(c, jobs) -//} +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/service/sd" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "chatplus/utils/resp" + "encoding/base64" + "fmt" + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "net/http" + "time" +) + +type SdJobHandler struct { + BaseHandler + redis *redis.Client + db *gorm.DB + service *sd.Service +} + +func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler { + h := SdJobHandler{ + redis: redisCli, + db: db, + service: service, + } + h.App = app + return &h +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *SdJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + sessionId := c.Query("session_id") + client := types.NewWsClient(ws) + // 删除旧的连接 + h.service.Clients.Put(sessionId, client) + logger.Infof("New websocket connected, IP: %s", c.ClientIP()) +} + +func (h *SdJobHandler) checkLimits(c *gin.Context) bool { + user, err := utils.GetLoginUser(c, h.db) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.ImgCalls <= 0 { + resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *SdJobHandler) Image(c *gin.Context) { + if !h.App.Config.SdConfig.Enabled { + resp.ERROR(c, "Stable Diffusion service is disabled") + return + } + + if !h.checkLimits(c) { + return + } + + var data struct { + SessionId string `json:"session_id"` + types.SdTaskParams + } + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + if data.Width <= 0 { + data.Width = 512 + } + if data.Height <= 0 { + data.Height = 512 + } + if data.CfgScale <= 0 { + data.CfgScale = 7 + } + if data.Seed == 0 { + data.Seed = -1 + } + if data.Steps <= 0 { + data.Steps = 20 + } + if data.Sampler == "" { + data.Sampler = "Euler a" + } + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + params := types.SdTaskParams{ + TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)), + Prompt: data.Prompt, + NegativePrompt: data.NegativePrompt, + Steps: data.Steps, + Sampler: data.Sampler, + FaceFix: data.FaceFix, + CfgScale: data.CfgScale, + Seed: data.Seed, + Height: data.Height, + Width: data.Width, + HdFix: data.HdFix, + HdRedrawRate: data.HdRedrawRate, + HdScale: data.HdScale, + HdScaleAlg: data.HdScaleAlg, + HdSteps: data.HdSteps, + } + job := model.SdJob{ + UserId: userId, + Type: types.TaskImage.String(), + TaskId: params.TaskId, + Params: utils.JsonEncode(params), + Prompt: data.Prompt, + Progress: 0, + Started: false, + CreatedAt: time.Now(), + } + res := h.db.Create(&job) + if res.Error != nil { + resp.ERROR(c, "error with save job: "+res.Error.Error()) + return + } + + h.service.PushTask(types.SdTask{ + Id: int(job.Id), + SessionId: data.SessionId, + Src: types.TaskSrcImg, + Type: types.TaskImage, + Prompt: data.Prompt, + Params: params, + UserId: userId, + }) + var jobVo vo.SdJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + // 推送任务到前端 + client := h.service.Clients.Get(data.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } + resp.SUCCESS(c) +} + +// JobList 获取 MJ 任务列表 +func (h *SdJobHandler) JobList(c *gin.Context) { + status := h.GetInt(c, "status", 0) + var items []model.SdJob + var res *gorm.DB + userId, _ := c.Get(types.LoginUserID) + if status == 1 { + res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) + } else { + res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) + } + if res.Error != nil { + resp.ERROR(c, types.NoData) + return + } + + var jobs = make([]vo.SdJob, 0) + for _, item := range items { + var job vo.SdJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + if item.Progress < 100 { + // 30 分钟还没完成的任务直接删除 + if time.Now().Sub(item.CreatedAt) > time.Minute*30 { + h.db.Delete(&item) + continue + } + if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 + image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } + } + jobs = append(jobs, job) + } + resp.SUCCESS(c, jobs) +} diff --git a/api/handler/sms_handler.go b/api/handler/sms_handler.go index fba56a03..b9be4e2b 100644 --- a/api/handler/sms_handler.go +++ b/api/handler/sms_handler.go @@ -66,5 +66,5 @@ type statusVo struct { // Status check if the message service is enabled func (h *SmsHandler) Status(c *gin.Context) { - resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsgService, EnabledRegister: h.App.SysConfig.EnabledRegister}) + resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister}) } diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index a7597ad1..490c833c 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -63,7 +63,7 @@ func (h *UserHandler) Register(c *gin.Context) { // 检查验证码 key := CodeStorePrefix + data.Mobile - if h.App.SysConfig.EnabledMsgService { + if h.App.SysConfig.EnabledMsg { var code int err := h.leveldb.Get(key, &code) if err != nil || code != data.Code { @@ -113,7 +113,7 @@ func (h *UserHandler) Register(c *gin.Context) { return } - if h.App.SysConfig.EnabledMsgService { + if h.App.SysConfig.EnabledMsg { _ = h.leveldb.Delete(key) // 注册成功,删除短信验证码 } resp.SUCCESS(c, user) diff --git a/api/main.go b/api/main.go index d30a5b10..a7be7a56 100644 --- a/api/main.go +++ b/api/main.go @@ -10,6 +10,7 @@ import ( "chatplus/service/fun" "chatplus/service/mj" "chatplus/service/oss" + "chatplus/service/sd" "chatplus/service/wx" "chatplus/store" "context" @@ -121,6 +122,7 @@ func main() { fx.Provide(handler.NewCaptchaHandler), fx.Provide(handler.NewMidJourneyHandler), fx.Provide(handler.NewChatModelHandler), + fx.Provide(handler.NewSdJobHandler), fx.Provide(admin.NewConfigHandler), fx.Provide(admin.NewAdminHandler), @@ -167,6 +169,13 @@ func main() { } }), + // Stable Diffusion 机器人 + fx.Provide(sd.NewService), + fx.Invoke(func(service *sd.Service) { + go func() { + service.Run() + }() + }), // 注册路由 fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { group := s.Engine.Group("/api/role/") @@ -220,6 +229,12 @@ func main() { group.GET("jobs", h.JobList) group.Any("client", h.Client) }), + fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { + group := s.Engine.Group("/api/sd") + group.POST("image", h.Image) + group.GET("jobs", h.JobList) + group.Any("client", h.Client) + }), // 管理后台控制器 fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 25c147bf..9ab320a2 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -20,7 +20,7 @@ import ( const RunningJobKey = "MidJourney_Running_Job" type Service struct { - client *Client + client *Client // MJ 客户端 taskQueue *store.RedisQueue redis *redis.Client db *gorm.DB @@ -128,7 +128,7 @@ func (s *Service) Notify(data CBReq) { // 任务完成,将最终的图片下载下来 if data.Progress == 100 { - imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL) + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) if err != nil { logger.Error("error with download img: ", err.Error()) return @@ -169,7 +169,7 @@ func (s *Service) Notify(data CBReq) { utils.ReplyMessage(wsClient, content) } // download image - imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL) + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) if err != nil { logger.Error("error with download image: ", err) if wsClient != nil && data.ReferenceId != "" { diff --git a/api/service/oss/aliyun_oss.go b/api/service/oss/aliyun_oss.go index 4dfb089c..d24d3a65 100644 --- a/api/service/oss/aliyun_oss.go +++ b/api/service/oss/aliyun_oss.go @@ -63,8 +63,14 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) { return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil } -func (s AliYunOss) PutImg(imageURL string) (string, error) { - imageData, err := utils.DownloadImage(imageURL, s.proxyURL) +func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } if err != nil { return "", fmt.Errorf("error with download image: %v", err) } diff --git a/api/service/oss/localstorage.go b/api/service/oss/localstorage.go index e0d19d79..a73aa2f7 100644 --- a/api/service/oss/localstorage.go +++ b/api/service/oss/localstorage.go @@ -41,14 +41,18 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) { return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil } -func (s LocalStorage) PutImg(imageURL string) (string, error) { +func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { filename := filepath.Base(imageURL) filePath, err := utils.GenUploadPath(s.config.BasePath, filename) if err != nil { return "", fmt.Errorf("error with generate image dir: %v", err) } - err = utils.DownloadFile(imageURL, filePath, s.proxyURL) + if useProxy { + err = utils.DownloadFile(imageURL, filePath, s.proxyURL) + } else { + err = utils.DownloadFile(imageURL, filePath, "") + } if err != nil { return "", fmt.Errorf("error with download image: %v", err) } diff --git a/api/service/oss/minio_oss.go b/api/service/oss/minio_oss.go index ee859540..186340cf 100644 --- a/api/service/oss/minio_oss.go +++ b/api/service/oss/minio_oss.go @@ -31,8 +31,14 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil } -func (s MiniOss) PutImg(imageURL string) (string, error) { - imageData, err := utils.DownloadImage(imageURL, s.proxyURL) +func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } if err != nil { return "", fmt.Errorf("error with download image: %v", err) } diff --git a/api/service/oss/qiniu_oss.go b/api/service/oss/qiniu_oss.go index a4f8dd62..266f1f52 100644 --- a/api/service/oss/qiniu_oss.go +++ b/api/service/oss/qiniu_oss.go @@ -72,8 +72,14 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) { return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil } -func (s QinNiuOss) PutImg(imageURL string) (string, error) { - imageData, err := utils.DownloadImage(imageURL, s.proxyURL) +func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } if err != nil { return "", fmt.Errorf("error with download image: %v", err) } diff --git a/api/service/oss/uploader.go b/api/service/oss/uploader.go index 1c8df952..f484c467 100644 --- a/api/service/oss/uploader.go +++ b/api/service/oss/uploader.go @@ -4,6 +4,6 @@ import "github.com/gin-gonic/gin" type Uploader interface { PutFile(ctx *gin.Context, name string) (string, error) - PutImg(imageURL string) (string, error) + PutImg(imageURL string, useProxy bool) (string, error) Delete(fileURL string) error } diff --git a/api/service/sd/client.go b/api/service/sd/client.go deleted file mode 100644 index c2abe021..00000000 --- a/api/service/sd/client.go +++ /dev/null @@ -1,169 +0,0 @@ -package sd - -import ( - "chatplus/core/types" - "chatplus/utils" - "fmt" - "github.com/imroc/req/v3" - "io" - "time" -) - -type Client struct { - httpClient *req.Client - config *types.StableDiffusionConfig -} - -func NewSdClient(config *types.AppConfig) *Client { - return &Client{ - config: &config.SdConfig, - httpClient: req.C(), - } -} - -func (c *Client) Txt2Img(params types.SdTaskParams) error { - var data []interface{} - err := utils.JsonDecode(Text2ImgParamTemplate, &data) - if err != nil { - return err - } - data[ParamKeys["task_id"]] = params.TaskId - data[ParamKeys["prompt"]] = params.Prompt - data[ParamKeys["negative_prompt"]] = params.NegativePrompt - data[ParamKeys["steps"]] = params.Steps - data[ParamKeys["sampler"]] = params.Sampler - data[ParamKeys["face_fix"]] = params.FaceFix - data[ParamKeys["cfg_scale"]] = params.CfgScale - data[ParamKeys["seed"]] = params.Seed - data[ParamKeys["height"]] = params.Height - data[ParamKeys["width"]] = params.Width - data[ParamKeys["hd_fix"]] = params.HdFix - data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate - data[ParamKeys["hd_scale"]] = params.HdScale - data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg - data[ParamKeys["hd_sample_num"]] = params.HdSampleNum - task := TaskInfo{ - TaskId: params.TaskId, - Data: data, - EventData: nil, - FnIndex: 494, - SessionHash: "ycaxgzm9ah", - } - - go func() { - c.runTask(task, c.httpClient) - }() - return nil -} - -func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) { - body := map[string]any{ - "data": taskInfo.Data, - "event_data": taskInfo.EventData, - "fn_index": taskInfo.FnIndex, - "session_hash": taskInfo.SessionHash, - } - - var result = make(chan CBReq) - go func() { - var res struct { - Data []interface{} `json:"data"` - IsGenerating bool `json:"is_generating"` - Duration float64 `json:"duration"` - AverageDuration float64 `json:"average_duration"` - } - var cbReq = CBReq{TaskId: taskInfo.TaskId} - response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict") - if err != nil { - cbReq.Message = "error with send request: " + err.Error() - cbReq.Success = false - result <- cbReq - return - } - - if response.IsErrorState() { - bytes, _ := io.ReadAll(response.Body) - cbReq.Message = "error http status code: " + string(bytes) - cbReq.Success = false - result <- cbReq - return - } - - var images []struct { - Name string `json:"name"` - Data interface{} `json:"data"` - IsFile bool `json:"is_file"` - } - err = utils.ForceCovert(res.Data[0], &images) - if err != nil { - cbReq.Message = "error with decode image:" + err.Error() - cbReq.Success = false - result <- cbReq - return - } - - var info map[string]any - err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) - if err != nil { - cbReq.Message = err.Error() - cbReq.Success = false - result <- cbReq - return - } - - //for k, v := range info { - // fmt.Println(k, " => ", v) - //} - cbReq.ImageName = images[0].Name - cbReq.Seed = utils.InterfaceToString(info["seed"]) - cbReq.Success = true - cbReq.Progress = 100 - result <- cbReq - close(result) - - }() - - for { - select { - case value := <-result: - if value.Success { - logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName) - } - return - default: - var progressReq = map[string]any{ - "id_task": taskInfo.TaskId, - "id_live_preview": 1, - } - - var progressRes struct { - Active bool `json:"active"` - Queued bool `json:"queued"` - Completed bool `json:"completed"` - Progress float64 `json:"progress"` - Eta float64 `json:"eta"` - LivePreview string `json:"live_preview"` - IDLivePreview int `json:"id_live_preview"` - TextInfo interface{} `json:"textinfo"` - } - response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress") - var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true} - if err != nil { // TODO: 这里可以考虑设置失败重试次数 - logger.Error(err) - return - } - - if response.IsErrorState() { - bytes, _ := io.ReadAll(response.Body) - logger.Error(string(bytes)) - return - } - - cbReq.ImageData = progressRes.LivePreview - cbReq.Progress = int(progressRes.Progress * 100) - fmt.Println("Progress: ", progressRes.Progress) - fmt.Println("Image: ", progressRes.LivePreview) - time.Sleep(time.Second) - } - } -} diff --git a/api/service/sd/sd_service.go b/api/service/sd/sd_service.go deleted file mode 100644 index da1cc980..00000000 --- a/api/service/sd/sd_service.go +++ /dev/null @@ -1,72 +0,0 @@ -package sd - -import ( - "chatplus/core/types" - "chatplus/service/mj" - "chatplus/store" - "chatplus/store/model" - "chatplus/utils" - "context" - "github.com/go-redis/redis/v8" - "gorm.io/gorm" - "time" -) - -// SD 绘画服务 - -const RunningJobKey = "StableDiffusion_Running_Job" - -type Service struct { - taskQueue *store.RedisQueue - redis *redis.Client - db *gorm.DB - Client *Client -} - -func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service { - return &Service{ - redis: redisCli, - db: db, - Client: client, - taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli), - } -} - -func (s *Service) Run() { - logger.Info("Starting StableDiffusion job consumer.") - ctx := context.Background() - for { - _, err := s.redis.Get(ctx, RunningJobKey).Result() - if err == nil { // 队列串行执行 - time.Sleep(time.Second * 3) - continue - } - var task types.SdTask - err = s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } - logger.Infof("Consuming Task: %+v", task) - err = s.Client.Txt2Img(task.Params) - if err != nil { - logger.Error("绘画任务执行失败:", err) - if task.RetryCount <= 5 { - s.taskQueue.RPush(task) - } - task.RetryCount += 1 - time.Sleep(time.Second * 3) - continue - } - - // 更新任务的执行状态 - s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) - // 锁定任务执行通道,直到任务超时(5分钟) - s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5) - } -} - -func (s *Service) PushTask(task types.SdTask) { - logger.Infof("add a new MidJourney Task: %+v", task) - s.taskQueue.RPush(task) -} diff --git a/api/service/sd/service.go b/api/service/sd/service.go new file mode 100644 index 00000000..7fc9943c --- /dev/null +++ b/api/service/sd/service.go @@ -0,0 +1,300 @@ +package sd + +import ( + "chatplus/core/types" + "chatplus/service/oss" + "chatplus/store" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "context" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/imroc/req/v3" + "gorm.io/gorm" + "io" + "strconv" + "time" +) + +// SD 绘画服务 + +const RunningJobKey = "StableDiffusion_Running_Job" + +type Service struct { + httpClient *req.Client + config *types.StableDiffusionConfig + taskQueue *store.RedisQueue + redis *redis.Client + db *gorm.DB + uploadManager *oss.UploaderManager + Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池 +} + +func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service { + return &Service{ + config: &config.SdConfig, + httpClient: req.C(), + redis: redisCli, + db: db, + uploadManager: manager, + Clients: types.NewLMap[string, *types.WsClient](), + taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli), + } +} + +func (s *Service) Run() { + logger.Info("Starting StableDiffusion job consumer.") + ctx := context.Background() + for { + _, err := s.redis.Get(ctx, RunningJobKey).Result() + if err == nil { // 队列串行执行 + time.Sleep(time.Second * 3) + continue + } + var task types.SdTask + err = s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + logger.Infof("Consuming Task: %+v", task) + err = s.Txt2Img(task) + if err != nil { + logger.Error("绘画任务执行失败:", err) + if task.RetryCount <= 5 { + s.taskQueue.RPush(task) + } + task.RetryCount += 1 + time.Sleep(time.Second * 3) + continue + } + + // 更新任务的执行状态 + s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) + // 锁定任务执行通道,直到任务超时(5分钟) + s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5) + } +} + +// PushTask 推送任务到队列 +func (s *Service) PushTask(task types.SdTask) { + logger.Infof("add a new MidJourney Task: %+v", task) + s.taskQueue.RPush(task) +} + +// Txt2Img 文生图 API +func (s *Service) Txt2Img(task types.SdTask) error { + var data []interface{} + err := utils.JsonDecode(Text2ImgParamTemplate, &data) + if err != nil { + return err + } + params := task.Params + data[ParamKeys["task_id"]] = params.TaskId + data[ParamKeys["prompt"]] = params.Prompt + data[ParamKeys["negative_prompt"]] = params.NegativePrompt + data[ParamKeys["steps"]] = params.Steps + data[ParamKeys["sampler"]] = params.Sampler + data[ParamKeys["face_fix"]] = params.FaceFix + data[ParamKeys["cfg_scale"]] = params.CfgScale + data[ParamKeys["seed"]] = params.Seed + data[ParamKeys["height"]] = params.Height + data[ParamKeys["width"]] = params.Width + data[ParamKeys["hd_fix"]] = params.HdFix + data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate + data[ParamKeys["hd_scale"]] = params.HdScale + data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg + data[ParamKeys["hd_sample_num"]] = params.HdSteps + + go func() { + s.runTask(TaskInfo{ + SessionId: task.SessionId, + JobId: task.Id, + TaskId: params.TaskId, + Data: data, + EventData: nil, + FnIndex: 405, + SessionHash: "ycaxgzm9ah", + }, s.httpClient) + }() + return nil +} + +// 执行任务 +func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { + body := map[string]any{ + "data": taskInfo.Data, + "event_data": taskInfo.EventData, + "fn_index": taskInfo.FnIndex, + "session_hash": taskInfo.SessionHash, + } + logger.Debug(utils.JsonEncode(body)) + var result = make(chan CBReq) + go func() { + var res struct { + Data []interface{} `json:"data"` + IsGenerating bool `json:"is_generating"` + Duration float64 `json:"duration"` + AverageDuration float64 `json:"average_duration"` + } + var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} + response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict") + if err != nil { + cbReq.Message = "error with send request: " + err.Error() + cbReq.Success = false + result <- cbReq + return + } + + if response.IsErrorState() { + bytes, _ := io.ReadAll(response.Body) + cbReq.Message = "error http status code: " + string(bytes) + cbReq.Success = false + result <- cbReq + return + } + + var images []struct { + Name string `json:"name"` + Data interface{} `json:"data"` + IsFile bool `json:"is_file"` + } + err = utils.ForceCovert(res.Data[0], &images) + if err != nil { + cbReq.Message = "error with decode image:" + err.Error() + cbReq.Success = false + result <- cbReq + return + } + + var info map[string]any + err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) + if err != nil { + cbReq.Message = err.Error() + cbReq.Success = false + result <- cbReq + return + } + + //for k, v := range info { + // fmt.Println(k, " => ", v) + //} + cbReq.ImageName = images[0].Name + seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64) + cbReq.Seed = seed + cbReq.Success = true + cbReq.Progress = 100 + result <- cbReq + close(result) + + }() + + for { + select { + case value := <-result: + s.callback(value) + return + default: + var progressReq = map[string]any{ + "id_task": taskInfo.TaskId, + "id_live_preview": 1, + } + + var progressRes struct { + Active bool `json:"active"` + Queued bool `json:"queued"` + Completed bool `json:"completed"` + Progress float64 `json:"progress"` + Eta float64 `json:"eta"` + LivePreview string `json:"live_preview"` + IDLivePreview int `json:"id_live_preview"` + TextInfo interface{} `json:"textinfo"` + } + response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress") + var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} + if err != nil { // TODO: 这里可以考虑设置失败重试次数 + logger.Error(err) + return + } + + if response.IsErrorState() { + bytes, _ := io.ReadAll(response.Body) + logger.Error(string(bytes)) + return + } + + cbReq.ImageData = progressRes.LivePreview + cbReq.Progress = int(progressRes.Progress * 100) + s.callback(cbReq) + time.Sleep(time.Second) + } + } +} + +func (s *Service) callback(data CBReq) { + client := s.Clients.Get(data.SessionId) + if data.Success { // 任务成功 + var job model.SdJob + res := s.db.Where("id = ?", data.JobId).First(&job) + if res.Error != nil { + logger.Warn("非法任务:", res.Error) + return + } + // 更新任务进度 + job.Progress = data.Progress + // 更新任务 seed + var params types.SdTaskParams + err := utils.JsonDecode(job.Params, ¶ms) + if err != nil { + logger.Error("任务解析失败:", err) + return + } + + params.Seed = data.Seed + if data.ImageName != "" { // 下载图片 + imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) + imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false) + if err != nil { + logger.Error("error with download img: ", err.Error()) + return + } + job.ImgURL = imageURL + } + + res = s.db.Updates(&job) + if res.Error != nil { + logger.Error("error with update job: ", res.Error) + return + } + + var jobVo vo.SdJob + err = utils.CopyObject(job, &jobVo) + if err != nil { + logger.Error("error with copy object: ", err) + return + } + + if data.Progress < 100 { + logger.Infof(data.ImageData) + jobVo.ImgURL = data.ImageData + } + + // 推送任务到前端 + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } else { // 任务失败 + logger.Error("任务执行失败:", data.Message) + // 删除任务 + s.db.Delete(&model.SdJob{Id: uint(data.JobId)}) + // 推送消息到前端 + if client != nil { + utils.ReplyChunkMessage(client, vo.SdJob{ + Id: uint(data.JobId), + Progress: -1, + Prompt: fmt.Sprintf("任务[%s]执行失败,已删除!", data.TaskId), + }) + } + } +} diff --git a/api/service/sd/types.go b/api/service/sd/types.go index b0942b89..8a052b1d 100644 --- a/api/service/sd/types.go +++ b/api/service/sd/types.go @@ -5,19 +5,23 @@ import logger2 "chatplus/logger" var logger = logger2.GetLogger() type TaskInfo struct { - TaskId string `json:"task_id"` - Data interface{} `json:"data"` - EventData interface{} `json:"event_data"` - FnIndex int `json:"fn_index"` - SessionHash string `json:"session_hash"` + SessionId string + JobId int + TaskId string + Data []interface{} + EventData interface{} + FnIndex int + SessionHash string } type CBReq struct { + SessionId string + JobId int TaskId string ImageName string ImageData string Progress int - Seed string + Seed int64 Success bool Message string } @@ -41,164 +45,170 @@ var ParamKeys = map[string]int{ } const Text2ImgParamTemplate = `[ -"", -"", +"task(p1lk3n41saygmr8)", +"a tiger sit on the window", "", [], -30, -"DPM++ SDE Karras", +20, +"Euler a", false, false, 1, 1, -7.5, +7, -1, -1, 0, 0, 0, false, -512, -512, -true, +128, +128, +false, 0.7, 2, "Latent", -10, 0, 0, -"Use same sampler", -"", -"", +0, [], "None", false, "MultiDiffusion", false, +10, +1, +1, +64, +false, true, 1024, 1024, 96, 96, 48, -4, +1, "None", 2, false, -10, +false, +false, +false, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +false, +true, +true, +false, +1536, +96, +false, +false, +"LoRA", +"None", 1, 1, -64, -false, -false, -false, -false, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -0.4, -0.4, -0.2, -0.2, -"", -"", -"Background", -0.2, --1, -false, -3072, -192, -true, -true, -true, -false, +"LoRA", +"None", +1, +1, +"LoRA", +"None", +1, +1, +"LoRA", +"None", +1, +1, +"LoRA", +"None", +1, +1, +null, +"Refresh models", +null, null, null, null, -false, -"", -0.5, -true, -false, -"", -"Lerp", -false, -"🔄", -false, -false, -false, -false, -false, -false, -false, false, false, "positive", @@ -209,26 +219,26 @@ false, "", "Seed", "", -[], "Nothing", "", -[], "Nothing", "", -[], true, false, false, false, 0, null, +false, null, false, null, -null, false, null, -null, false, -50 +50, +[], +"", +"", +"" ]` diff --git a/web/src/assets/css/image-sd.css b/web/src/assets/css/image-sd.css new file mode 100644 index 00000000..e702499e --- /dev/null +++ b/web/src/assets/css/image-sd.css @@ -0,0 +1,187 @@ +.page-sd { + background-color: #282c34; +} +.page-sd .inner { + display: flex; +/* 修改滚动条的颜色 */ +/* 修改滚动条轨道的背景颜色 */ +/* 修改滚动条的滑块颜色 */ +/* 修改滚动条的滑块的悬停颜色 */ +} +.page-sd .inner .sd-box { + margin: 10px; + background-color: #262626; + border: 1px solid #454545; + min-width: 300px; + max-width: 300px; + padding: 10px; + border-radius: 10px; + color: #fff; + font-size: 14px; +} +.page-sd .inner .sd-box h2 { + font-weight: bold; + font-size: 20px; + text-align: center; + color: #47fff1; +} +.page-sd .inner .sd-box ::-webkit-scrollbar { + width: 0; + height: 0; + background-color: transparent; +} +.page-sd .inner .sd-box .sd-params { + margin-top: 10px; + overflow: auto; +} +.page-sd .inner .sd-box .sd-params .param-line { + padding: 0 10px; +} +.page-sd .inner .sd-box .sd-params .param-line .el-icon { + position: relative; + top: 3px; +} +.page-sd .inner .sd-box .sd-params .param-line .el-input__suffix-inner .el-icon { + top: 0; +} +.page-sd .inner .sd-box .sd-params .param-line .grid-content, +.page-sd .inner .sd-box .sd-params .param-line .form-item-inner { + display: flex; +} +.page-sd .inner .sd-box .sd-params .param-line .grid-content .el-icon, +.page-sd .inner .sd-box .sd-params .param-line .form-item-inner .el-icon { + margin-left: 10px; + margin-top: 2px; +} +.page-sd .inner .sd-box .sd-params .param-line.pt { + padding-top: 5px; + padding-bottom: 5px; +} +.page-sd .inner .sd-box .submit-btn { + padding: 10px 15px 0 15px; + text-align: center; +} +.page-sd .inner .sd-box .submit-btn .el-button { + width: 100%; +} +.page-sd .inner .sd-box .submit-btn .el-button span { + color: #2d3a4b; +} +.page-sd .inner .el-form .el-form-item__label { + color: #fff; +} +.page-sd .inner ::-webkit-scrollbar { + width: 10px; /* 滚动条宽度 */ +} +.page-sd .inner ::-webkit-scrollbar-track { + background-color: #282c34; +} +.page-sd .inner ::-webkit-scrollbar-thumb { + background-color: #444; + border-radius: 10px; +} +.page-sd .inner ::-webkit-scrollbar-thumb:hover { + background-color: #666; +} +.page-sd .inner .task-list-box { + width: 100%; + padding: 10px; + color: #fff; + overflow-x: hidden; +} +.page-sd .inner .task-list-box .running-job-list .job-item { + width: 100%; + padding: 2px; + background-color: #555; +} +.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner { + position: relative; + height: 100%; + overflow: hidden; +} +.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner .progress { + position: absolute; + width: 100%; + height: 100%; + top: 0; + left: 0; + display: flex; + justify-content: center; + align-items: center; +} +.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner .progress span { + font-size: 20px; + color: #fff; +} +.page-sd .inner .task-list-box .finish-job-list .job-item { + width: 100%; + height: 100%; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line { + margin: 6px 0; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul { + display: flex; + flex-flow: row; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li { + margin-right: 10px; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a { + padding: 3px 0; + width: 44px; + text-align: center; + border-radius: 5px; + display: block; + cursor: pointer; + background-color: #4e5058; + color: #fff; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a:hover { + background-color: #6d6f78; +} +.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul .show-prompt { + font-size: 20px; + cursor: pointer; +} +.page-sd .inner .task-list-box .el-image { + width: 100%; + height: 100%; + max-height: 240px; +} +.page-sd .inner .task-list-box .el-image img { + height: 240px; +} +.page-sd .inner .task-list-box .el-image .el-image-viewer__wrapper img { + width: auto; + height: auto; +} +.page-sd .inner .task-list-box .el-image .image-slot { + display: flex; + flex-flow: column; + justify-content: center; + align-items: center; + height: 100%; + min-height: 200px; + color: #fff; +} +.page-sd .inner .task-list-box .el-image .image-slot .iconfont { + font-size: 50px; + margin-bottom: 10px; +} +.page-sd .inner .task-list-box .el-image.upscale { + max-height: 304px; +} +.page-sd .inner .task-list-box .el-image.upscale img { + height: 304px; +} +.page-sd .inner .task-list-box .el-image.upscale .el-image-viewer__wrapper img { + width: auto; + height: auto; +} +.mj-list-item-prompt .el-icon { + margin-left: 10px; + cursor: pointer; + position: relative; + top: 2px; +} diff --git a/web/src/assets/css/image-sd.styl b/web/src/assets/css/image-sd.styl new file mode 100644 index 00000000..7c740d0f --- /dev/null +++ b/web/src/assets/css/image-sd.styl @@ -0,0 +1,255 @@ +.page-sd { + background-color: #282c34; + + .inner { + display: flex; + + .sd-box { + margin 10px + background-color #262626 + border 1px solid #454545 + min-width 300px + max-width 300px + padding 10px + border-radius 10px + color #ffffff; + font-size 14px + + h2 { + font-weight: bold; + font-size 20px + text-align center + color #47fff1 + } + + // 隐藏滚动条 + + ::-webkit-scrollbar { + width: 0; + height: 0; + background-color: transparent; + } + + .sd-params { + margin-top 10px + overflow auto + + + .param-line { + padding 0 10px + + .el-icon { + position relative + top 3px + } + + .el-input__suffix-inner { + .el-icon { + top 0 + } + } + + .grid-content + .form-item-inner { + display flex + + .el-icon { + margin-left 10px + margin-top 2px + } + } + + } + + .param-line.pt { + padding-top 5px + padding-bottom 5px + } + } + + .submit-btn { + padding 10px 15px 0 15px + text-align center + + .el-button { + width 100% + + span { + color #2D3A4B + } + } + } + } + + .el-form { + .el-form-item__label { + color #ffffff + } + } + + /* 修改滚动条的颜色 */ + + ::-webkit-scrollbar { + width: 10px; /* 滚动条宽度 */ + } + + /* 修改滚动条轨道的背景颜色 */ + + ::-webkit-scrollbar-track { + background-color: #282C34; + } + + /* 修改滚动条的滑块颜色 */ + + ::-webkit-scrollbar-thumb { + background-color: #444444; + border-radius 10px + } + + /* 修改滚动条的滑块的悬停颜色 */ + + ::-webkit-scrollbar-thumb:hover { + background-color: #666666; + } + + .task-list-box { + width 100% + padding 10px + color #ffffff + overflow-x hidden + + .running-job-list { + .job-item { + //border: 1px solid #454545; + width: 100%; + padding 2px + background-color #555555 + + .job-item-inner { + position relative + height 100% + overflow hidden + + .progress { + position absolute + width 100% + height 100% + top 0 + left 0 + display flex + justify-content center + align-items center + + span { + font-size 20px + color #ffffff + } + } + } + } + } + + + .finish-job-list { + .job-item { + width 100% + height 100% + + .opt { + .opt-line { + margin 6px 0 + + ul { + display flex + flex-flow row + + li { + margin-right 10px + + a { + padding 3px 0 + width 44px + text-align center + border-radius 5px + display block + cursor pointer + background-color #4E5058 + color #ffffff + + &:hover { + background-color #6D6F78 + } + } + } + + .show-prompt { + font-size 20px + cursor pointer + } + } + } + } + + } + + } + + .el-image { + width 100% + height 100% + max-height 240px + + img { + height 240px + } + + .el-image-viewer__wrapper { + img { + width auto + height auto + } + } + + .image-slot { + display flex + flex-flow column + justify-content center + align-items center + height 100% + min-height 200px + color #ffffff + + .iconfont { + font-size 50px + margin-bottom 10px + } + } + } + + .el-image.upscale { + max-height 304px + + img { + height 304px + } + + .el-image-viewer__wrapper { + img { + width auto + height auto + } + } + } + } + } + +} + +.mj-list-item-prompt { + .el-icon { + margin-left 10px + cursor pointer + position relative + top 2px + } +} \ No newline at end of file diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index aeec7b6a..8f8cf33a 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -230,13 +230,13 @@ placement="top-start" :title="getTaskType(scope.item.type)" :width="240" - trigger="click" + trigger="hover" >