diff --git a/api/core/app_server.go b/api/core/app_server.go index 96747bc3..9da122d4 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -159,7 +159,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { var tokenString string if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API tokenString = c.GetHeader(types.AdminAuthHeader) - } else if c.Request.URL.Path == "/api/chat/new" { + } else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" { tokenString = c.Query("token") } else { tokenString = c.GetHeader(types.UserAuthHeader) diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index a317d6ce..930c7841 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -86,11 +86,10 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { sessionId := c.Query("session_id") client := types.NewWsClient(ws) - // 关闭旧的连接 - if h.clients.Has(sessionId) { - h.clients.Get(sessionId).Close() - } + // 删除旧的连接 + h.clients.Delete(sessionId) h.clients.Put(sessionId, client) + logger.Infof("New websocket connected, IP: %s", c.ClientIP()) } func (h *MidJourneyHandler) Notify(c *gin.Context) { @@ -265,9 +264,30 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro } } + // 更新用户剩余绘图次数 + if data.Status == Finished && task.Type != service.Upscale { + h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) + } + return nil, true } +func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { + user, err := utils.GetLoginUser(c, h.db) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.ImgCalls <= 0 { + resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") + return false + } + + return true + +} + // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { @@ -286,6 +306,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + if h.checkLimits(c) { + return + } + var prompt = data.Prompt if data.Rate != "" && !strings.Contains(prompt, "--ar") { prompt += " --ar " + data.Rate @@ -367,6 +391,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } + if h.checkLimits(c) { + return + } + idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) @@ -432,6 +460,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } + if h.checkLimits(c) { + return + } + idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) diff --git a/api/main.go b/api/main.go index d9a2559e..eb6feb77 100644 --- a/api/main.go +++ b/api/main.go @@ -195,6 +195,7 @@ func main() { group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) group.GET("jobs", h.JobList) + group.Any("client", h.Client) }), // 管理后台控制器 diff --git a/web/src/assets/css/image-mj.styl b/web/src/assets/css/image-mj.styl index 6375c4f4..beebbdef 100644 --- a/web/src/assets/css/image-mj.styl +++ b/web/src/assets/css/image-mj.styl @@ -322,6 +322,13 @@ img { height 304px } + + .el-image-viewer__wrapper { + img { + width auto + height auto + } + } } } } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 5b85afad..cf90e68e 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -285,7 +285,7 @@ placement="top-start" title="提示词" :width="240" - trigger="click" + trigger="hover" >