opt: add sessionId for mj task

This commit is contained in:
RockYang
2023-09-19 18:15:08 +08:00
parent 2a71c2b0e7
commit b4b9df81cb
8 changed files with 234 additions and 82 deletions

View File

@@ -13,7 +13,9 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"sync"
"time"
@@ -43,6 +45,7 @@ type MidJourneyHandler struct {
mjService *service.MjService
uploaderManager *oss.UploaderManager
lock sync.Mutex
clients *types.LMap[string, *types.WsClient]
}
func NewMidJourneyHandler(
@@ -57,6 +60,7 @@ func NewMidJourneyHandler(
uploaderManager: manager,
lock: sync.Mutex{},
mjService: mjService,
clients: types.NewLMap[string, *types.WsClient](),
}
h.App = app
return &h
@@ -72,6 +76,23 @@ type notifyData struct {
Progress int `json:"progress"`
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 关闭旧的连接
if h.clients.Has(sessionId) {
h.clients.Get(sessionId).Close()
}
h.clients.Put(sessionId, client)
}
func (h *MidJourneyHandler) Notify(c *gin.Context) {
token := c.GetHeader("Authorization")
if token != h.App.Config.ExtConfig.Token {
@@ -154,8 +175,23 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
return res.Error, false
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
// 推送任务到前端
client := h.clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} else if task.Src == service.TaskSrcChat { // 聊天任务
wsClient := h.App.MjTaskClients.Get(task.Id)
wsClient := h.App.MjTaskClients.Get(task.SessionId)
if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
@@ -216,7 +252,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// 本次绘画完毕,移除客户端
h.App.MjTaskClients.Delete(task.Id)
h.App.MjTaskClients.Delete(task.SessionId)
} else {
// 使用代理临时转发图片
if data.Image.URL != "" {
@@ -235,15 +271,16 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Weight float32 `json:"weight"`
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Weight float32 `json:"weight"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -268,6 +305,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
}
if data.Raw {
prompt += " --style raw"
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += data.Model
}
@@ -287,12 +327,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
}
h.mjService.PushTask(service.MjTask{
Id: fmt.Sprintf("%d", job.Id),
Src: service.TaskSrcImg,
Type: service.Image,
Prompt: prompt,
UserId: userId,
Id: int(job.Id),
SessionId: data.SessionId,
Src: service.TaskSrcImg,
Type: service.Image,
Prompt: prompt,
UserId: userId,
})
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c)
}
@@ -317,7 +368,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := data.SessionId
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
@@ -330,14 +381,25 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = fmt.Sprintf("%d", job.Id)
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(service.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Upscale,
Prompt: data.Prompt,
@@ -358,6 +420,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
h.App.MjTaskClients.Put(data.SessionId, wsClient)
}
}
resp.SUCCESS(c)
}
@@ -370,7 +433,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := data.SessionId
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
@@ -384,14 +447,25 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = fmt.Sprintf("%d", job.Id)
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(service.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Variation,
Prompt: data.Prompt,