diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index ff906260..a317d6ce 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -13,7 +13,9 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" "gorm.io/gorm" + "net/http" "strings" "sync" "time" @@ -43,6 +45,7 @@ type MidJourneyHandler struct { mjService *service.MjService uploaderManager *oss.UploaderManager lock sync.Mutex + clients *types.LMap[string, *types.WsClient] } func NewMidJourneyHandler( @@ -57,6 +60,7 @@ func NewMidJourneyHandler( uploaderManager: manager, lock: sync.Mutex{}, mjService: mjService, + clients: types.NewLMap[string, *types.WsClient](), } h.App = app return &h @@ -72,6 +76,23 @@ type notifyData struct { Progress int `json:"progress"` } +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *MidJourneyHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + sessionId := c.Query("session_id") + client := types.NewWsClient(ws) + // 关闭旧的连接 + if h.clients.Has(sessionId) { + h.clients.Get(sessionId).Close() + } + h.clients.Put(sessionId, client) +} + func (h *MidJourneyHandler) Notify(c *gin.Context) { token := c.GetHeader("Authorization") if token != h.App.Config.ExtConfig.Token { @@ -154,8 +175,23 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro return res.Error, false } + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) + if err == nil { + jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + + // 推送任务到前端 + client := h.clients.Get(task.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } + } else if task.Src == service.TaskSrcChat { // 聊天任务 - wsClient := h.App.MjTaskClients.Get(task.Id) + wsClient := h.App.MjTaskClients.Get(task.SessionId) if data.Status == Finished { if wsClient != nil && data.ReferenceId != "" { content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) @@ -216,7 +252,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) // 本次绘画完毕,移除客户端 - h.App.MjTaskClients.Delete(task.Id) + h.App.MjTaskClients.Delete(task.SessionId) } else { // 使用代理临时转发图片 if data.Image.URL != "" { @@ -235,15 +271,16 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { - Prompt string `json:"prompt"` - Rate string `json:"rate"` - Model string `json:"model"` - Chaos int `json:"chaos"` - Raw bool `json:"raw"` - Seed int64 `json:"seed"` - Stylize int `json:"stylize"` - Img string `json:"img"` - Weight float32 `json:"weight"` + SessionId string `json:"session_id"` + Prompt string `json:"prompt"` + Rate string `json:"rate"` + Model string `json:"model"` + Chaos int `json:"chaos"` + Raw bool `json:"raw"` + Seed int64 `json:"seed"` + Stylize int `json:"stylize"` + Img string `json:"img"` + Weight float32 `json:"weight"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -268,6 +305,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { prompt += fmt.Sprintf(" --iw %f", data.Weight) } } + if data.Raw { + prompt += " --style raw" + } if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { prompt += data.Model } @@ -287,12 +327,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { } h.mjService.PushTask(service.MjTask{ - Id: fmt.Sprintf("%d", job.Id), - Src: service.TaskSrcImg, - Type: service.Image, - Prompt: prompt, - UserId: userId, + Id: int(job.Id), + SessionId: data.SessionId, + Src: service.TaskSrcImg, + Type: service.Image, + Prompt: prompt, + UserId: userId, }) + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + // 推送任务到前端 + client := h.clients.Get(data.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } resp.SUCCESS(c) } @@ -317,7 +368,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { } idValue, _ := c.Get(types.LoginUserID) - jobId := data.SessionId + jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) src := service.TaskSrc(data.Src) if src == service.TaskSrcImg { @@ -330,14 +381,25 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { CreatedAt: time.Now(), } if res := h.db.Create(&job); res.Error == nil { - jobId = fmt.Sprintf("%d", job.Id) + jobId = int(job.Id) } else { resp.ERROR(c, "添加任务失败:"+res.Error.Error()) return } + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + // 推送任务到前端 + client := h.clients.Get(data.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } } h.mjService.PushTask(service.MjTask{ Id: jobId, + SessionId: data.SessionId, Src: src, Type: service.Upscale, Prompt: data.Prompt, @@ -358,6 +420,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { h.App.MjTaskClients.Put(data.SessionId, wsClient) } } + resp.SUCCESS(c) } @@ -370,7 +433,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { } idValue, _ := c.Get(types.LoginUserID) - jobId := data.SessionId + jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) src := service.TaskSrc(data.Src) if src == service.TaskSrcImg { @@ -384,14 +447,25 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { CreatedAt: time.Now(), } if res := h.db.Create(&job); res.Error == nil { - jobId = fmt.Sprintf("%d", job.Id) + jobId = int(job.Id) } else { resp.ERROR(c, "添加任务失败:"+res.Error.Error()) return } + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + // 推送任务到前端 + client := h.clients.Get(data.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } } h.mjService.PushTask(service.MjTask{ Id: jobId, + SessionId: data.SessionId, Src: src, Type: service.Variation, Prompt: data.Prompt, diff --git a/api/service/function/func_mj.go b/api/service/function/func_mj.go index 19c84407..4ecf0f1a 100644 --- a/api/service/function/func_mj.go +++ b/api/service/function/func_mj.go @@ -22,14 +22,14 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) f.service.PushTask(service.MjTask{ - Id: utils.InterfaceToString(params["session_id"]), - Src: service.TaskSrcChat, - Type: service.Image, - Prompt: prompt, - UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), - RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), - Icon: utils.InterfaceToString(params["icon"]), - ChatId: utils.InterfaceToString(params["chat_id"]), + SessionId: utils.InterfaceToString(params["session_id"]), + Src: service.TaskSrcChat, + Type: service.Image, + Prompt: prompt, + UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), + RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), + Icon: utils.InterfaceToString(params["icon"]), + ChatId: utils.InterfaceToString(params["chat_id"]), }) return prompt, nil } diff --git a/api/service/mj_service.go b/api/service/mj_service.go index ae23d33e..7b46638a 100644 --- a/api/service/mj_service.go +++ b/api/service/mj_service.go @@ -41,7 +41,8 @@ const ( ) type MjTask struct { - Id string `json:"id"` + Id int `json:"id"` + SessionId string `json:"session_id"` Src TaskSrc `json:"src"` Type TaskType `json:"type"` UserId int `json:"user_id"` diff --git a/api/utils/net.go b/api/utils/net.go index 04e7f737..a45ecb44 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -12,7 +12,7 @@ import ( var logger = logger2.GetLogger() // ReplyChunkMessage 回复客户片段端消息 -func ReplyChunkMessage(client *types.WsClient, message types.WsMessage) { +func ReplyChunkMessage(client *types.WsClient, message interface{}) { msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) diff --git a/web/src/assets/css/image-mj.css b/web/src/assets/css/image-mj.css index 988ee2f8..8533ac9d 100644 --- a/web/src/assets/css/image-mj.css +++ b/web/src/assets/css/image-mj.css @@ -185,31 +185,15 @@ display: flex; justify-content: center; align-items: center; - background-color: rgba(0,0,0,0.5); } .page-mj .inner .task-list-box .running-job-list .job-item .job-item-inner .progress span { font-size: 20px; color: #fff; } -.page-mj .inner .task-list-box .running-job-list .job-item .el-image { +.page-mj .inner .task-list-box .finish-job-list .job-item { width: 100%; height: 100%; } -.page-mj .inner .task-list-box .running-job-list .job-item .el-image .image-slot { - display: flex; - flex-flow: column; - justify-content: center; - align-items: center; - height: 100%; - color: #fff; -} -.page-mj .inner .task-list-box .running-job-list .job-item .el-image .image-slot .iconfont { - font-size: 50px; - margin-bottom: 10px; -} -.page-mj .inner .task-list-box .finish-job-list .job-item { - margin-bottom: 20px; -} .page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line { margin: 6px 0; } @@ -233,6 +217,37 @@ .page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a:hover { background-color: #6d6f78; } +.page-mj .inner .task-list-box .el-image { + width: 100%; + height: 100%; + max-height: 240px; +} +.page-mj .inner .task-list-box .el-image img { + height: 240px; +} +.page-mj .inner .task-list-box .el-image .el-image-viewer__wrapper img { + width: auto; + height: auto; +} +.page-mj .inner .task-list-box .el-image .image-slot { + display: flex; + flex-flow: column; + justify-content: center; + align-items: center; + height: 100%; + min-height: 200px; + color: #fff; +} +.page-mj .inner .task-list-box .el-image .image-slot .iconfont { + font-size: 50px; + margin-bottom: 10px; +} +.page-mj .inner .task-list-box .el-image.upscale { + max-height: 304px; +} +.page-mj .inner .task-list-box .el-image.upscale img { + height: 304px; +} .mj-list-item-prompt .el-icon { margin-left: 10px; cursor: pointer; diff --git a/web/src/assets/css/image-mj.styl b/web/src/assets/css/image-mj.styl index 85362067..6375c4f4 100644 --- a/web/src/assets/css/image-mj.styl +++ b/web/src/assets/css/image-mj.styl @@ -247,7 +247,6 @@ .finish-job-list { .job-item { - margin-bottom 20px width 100% height 100% @@ -316,6 +315,14 @@ } } } + + .el-image.upscale { + max-height 304px + + img { + height 304px + } + } } } diff --git a/web/src/utils/http.js b/web/src/utils/http.js index 205fd0b5..dd0974cb 100644 --- a/web/src/utils/http.js +++ b/web/src/utils/http.js @@ -1,7 +1,7 @@ import axios from 'axios' import {getAdminToken, getSessionId, getUserToken} from "@/store/session"; -axios.defaults.timeout = 10000 +axios.defaults.timeout = 30000 axios.defaults.baseURL = process.env.VUE_APP_API_HOST axios.defaults.withCredentials = true; axios.defaults.headers.post['Content-Type'] = 'application/json' diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 48635187..5b85afad 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -285,10 +285,10 @@ placement="top-start" title="提示词" :width="240" - trigger="hover" + trigger="click" >