package handler import ( "chatplus/core" "chatplus/core/types" "chatplus/service" "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" "encoding/base64" "fmt" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" "gorm.io/gorm" "net/http" "strings" "sync" "time" ) type SdJobHandler struct { BaseHandler redis *redis.Client db *gorm.DB mjService *service.MjService uploaderManager *oss.UploaderManager lock sync.Mutex clients *types.LMap[string, *types.WsClient] } func NewSdJobHandler( app *core.AppServer, client *redis.Client, db *gorm.DB, manager *oss.UploaderManager, mjService *service.MjService) *MidJourneyHandler { h := MidJourneyHandler{ redis: client, db: db, uploaderManager: manager, lock: sync.Mutex{}, mjService: mjService, clients: types.NewLMap[string, *types.WsClient](), } h.App = app return &h } // Client WebSocket 客户端,用于通知任务状态变更 func (h *SdJobHandler) Client(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Error(err) return } sessionId := c.Query("session_id") client := types.NewWsClient(ws) // 删除旧的连接 h.clients.Delete(sessionId) h.clients.Put(sessionId, client) logger.Infof("New websocket connected, IP: %s", c.ClientIP()) } type sdNotifyData struct { TaskId string ImageName string ImageData string Progress int Seed string Success bool Message string } func (h *SdJobHandler) Notify(c *gin.Context) { token := c.GetHeader("Authorization") if token != h.App.Config.ExtConfig.Token { resp.NotAuth(c) return } var data sdNotifyData if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" { resp.ERROR(c, types.InvalidArgs) return } logger.Debugf("收到 MidJourney 回调请求:%+v", data) h.lock.Lock() defer h.lock.Unlock() err, finished := h.notifyHandler(c, data) if err != nil { resp.ERROR(c, err.Error()) return } // 解除任务锁定 if finished && (data.Progress == 100) { h.redis.Del(c, service.MjRunningJobKey) } resp.SUCCESS(c) } func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) { taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() if err != nil { // 过期任务,丢弃 logger.Warn("任务已过期:", err) return nil, true } var task types.SdTask err = utils.JsonDecode(taskString, &task) if err != nil { // 非标准任务,丢弃 logger.Warn("任务解析失败:", err) return nil, false } var job model.SdJob res := h.db.Where("id = ?", task.Id).First(&job) if res.Error != nil { logger.Warn("非法任务:", res.Error) return nil, false } job.Params = utils.JsonEncode(task.Params) job.ReferenceId = data.ImageData job.Progress = data.Progress job.Prompt = data.Prompt job.Hash = data.Image.Hash // 任务完成,将最终的图片下载下来 if data.Progress == 100 { imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) if err != nil { logger.Error("error with download img: ", err.Error()) return err, false } job.ImgURL = imgURL } else { // 临时图片直接保存,访问的时候使用代理进行转发 job.ImgURL = data.Image.URL } res = h.db.Updates(&job) if res.Error != nil { logger.Error("error with update job: ", res.Error) return res.Error, false } var jobVo vo.MidJourneyJob err := utils.CopyObject(job, &jobVo) if err == nil { if data.Progress < 100 { image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) if err == nil { jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) } } // 推送任务到前端 client := h.clients.Get(task.SessionId) if client != nil { utils.ReplyChunkMessage(client, jobVo) } } // 更新用户剩余绘图次数 if data.Progress == 100 { h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } return nil, true } func (h *SdJobHandler) checkLimits(c *gin.Context) bool { user, err := utils.GetLoginUser(c, h.db) if err != nil { resp.NotAuth(c) return false } if user.ImgCalls <= 0 { resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") return false } return true } // Image 创建一个绘画任务 func (h *SdJobHandler) Image(c *gin.Context) { var data struct { SessionId string `json:"session_id"` Prompt string `json:"prompt"` Rate string `json:"rate"` Model string `json:"model"` Chaos int `json:"chaos"` Raw bool `json:"raw"` Seed int64 `json:"seed"` Stylize int `json:"stylize"` Img string `json:"img"` Weight float32 `json:"weight"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) return } if !h.checkLimits(c) { return } var prompt = data.Prompt if data.Rate != "" && !strings.Contains(prompt, "--ar") { prompt += " --ar " + data.Rate } if data.Seed > 0 && !strings.Contains(prompt, "--seed") { prompt += fmt.Sprintf(" --seed %d", data.Seed) } if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { prompt += fmt.Sprintf(" --s %d", data.Stylize) } if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { prompt += fmt.Sprintf(" --c %d", data.Chaos) } if data.Img != "" { prompt = fmt.Sprintf("%s %s", data.Img, prompt) if data.Weight > 0 { prompt += fmt.Sprintf(" --iw %f", data.Weight) } } if data.Raw { prompt += " --style raw" } if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { prompt += data.Model } idValue, _ := c.Get(types.LoginUserID) userId := utils.IntValue(utils.InterfaceToString(idValue), 0) job := model.MidJourneyJob{ Type: service.Image.String(), UserId: userId, Progress: 0, Prompt: prompt, CreatedAt: time.Now(), } if res := h.db.Create(&job); res.Error != nil { resp.ERROR(c, "添加任务失败:"+res.Error.Error()) return } h.mjService.PushTask(service.MjTask{ Id: int(job.Id), SessionId: data.SessionId, Src: service.TaskSrcImg, Type: service.Image, Prompt: prompt, UserId: userId, }) var jobVo vo.MidJourneyJob err := utils.CopyObject(job, &jobVo) if err == nil { // 推送任务到前端 client := h.clients.Get(data.SessionId) if client != nil { utils.ReplyChunkMessage(client, jobVo) } } resp.SUCCESS(c) } // JobList 获取 MJ 任务列表 func (h *SdJobHandler) JobList(c *gin.Context) { status := h.GetInt(c, "status", 0) var items []model.MidJourneyJob var res *gorm.DB userId, _ := c.Get(types.LoginUserID) if status == 1 { res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) } else { res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) } if res.Error != nil { resp.ERROR(c, types.NoData) return } var jobs = make([]vo.MidJourneyJob, 0) for _, item := range items { var job vo.MidJourneyJob err := utils.CopyObject(item, &job) if err != nil { continue } if item.Progress < 100 { // 30 分钟还没完成的任务直接删除 if time.Now().Sub(item.CreatedAt) > time.Minute*30 { h.db.Delete(&item) continue } if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) if err == nil { job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) } } } jobs = append(jobs, job) } resp.SUCCESS(c, jobs) }