mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-05 00:33:47 +08:00
fix socket connect for mj task notify
This commit is contained in:
@@ -159,7 +159,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
var tokenString string
|
||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
||||
} else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" {
|
||||
tokenString = c.Query("token")
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
|
||||
@@ -86,11 +86,10 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
client := types.NewWsClient(ws)
|
||||
// 关闭旧的连接
|
||||
if h.clients.Has(sessionId) {
|
||||
h.clients.Get(sessionId).Close()
|
||||
}
|
||||
// 删除旧的连接
|
||||
h.clients.Delete(sessionId)
|
||||
h.clients.Put(sessionId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
@@ -265,9 +264,30 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户剩余绘图次数
|
||||
if data.Status == Finished && task.Type != service.Upscale {
|
||||
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||
user, err := utils.GetLoginUser(c, h.db)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return false
|
||||
}
|
||||
|
||||
if user.ImgCalls <= 0 {
|
||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -286,6 +306,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if h.checkLimits(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var prompt = data.Prompt
|
||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||
prompt += " --ar " + data.Rate
|
||||
@@ -367,6 +391,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.checkLimits(c) {
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
@@ -432,6 +460,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.checkLimits(c) {
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
|
||||
@@ -195,6 +195,7 @@ func main() {
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
|
||||
Reference in New Issue
Block a user