opt: add sessionId for mj task

This commit is contained in:
RockYang 2023-09-19 18:15:08 +08:00
parent 2a71c2b0e7
commit b4b9df81cb
8 changed files with 234 additions and 82 deletions

View File

@ -13,7 +13,9 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -43,6 +45,7 @@ type MidJourneyHandler struct {
mjService *service.MjService mjService *service.MjService
uploaderManager *oss.UploaderManager uploaderManager *oss.UploaderManager
lock sync.Mutex lock sync.Mutex
clients *types.LMap[string, *types.WsClient]
} }
func NewMidJourneyHandler( func NewMidJourneyHandler(
@ -57,6 +60,7 @@ func NewMidJourneyHandler(
uploaderManager: manager, uploaderManager: manager,
lock: sync.Mutex{}, lock: sync.Mutex{},
mjService: mjService, mjService: mjService,
clients: types.NewLMap[string, *types.WsClient](),
} }
h.App = app h.App = app
return &h return &h
@ -72,6 +76,23 @@ type notifyData struct {
Progress int `json:"progress"` Progress int `json:"progress"`
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 关闭旧的连接
if h.clients.Has(sessionId) {
h.clients.Get(sessionId).Close()
}
h.clients.Put(sessionId, client)
}
func (h *MidJourneyHandler) Notify(c *gin.Context) { func (h *MidJourneyHandler) Notify(c *gin.Context) {
token := c.GetHeader("Authorization") token := c.GetHeader("Authorization")
if token != h.App.Config.ExtConfig.Token { if token != h.App.Config.ExtConfig.Token {
@ -154,8 +175,23 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
return res.Error, false return res.Error, false
} }
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
// 推送任务到前端
client := h.clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} else if task.Src == service.TaskSrcChat { // 聊天任务 } else if task.Src == service.TaskSrcChat { // 聊天任务
wsClient := h.App.MjTaskClients.Get(task.Id) wsClient := h.App.MjTaskClients.Get(task.SessionId)
if data.Status == Finished { if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" { if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
@ -216,7 +252,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// 本次绘画完毕,移除客户端 // 本次绘画完毕,移除客户端
h.App.MjTaskClients.Delete(task.Id) h.App.MjTaskClients.Delete(task.SessionId)
} else { } else {
// 使用代理临时转发图片 // 使用代理临时转发图片
if data.Image.URL != "" { if data.Image.URL != "" {
@ -235,15 +271,16 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct { var data struct {
Prompt string `json:"prompt"` SessionId string `json:"session_id"`
Rate string `json:"rate"` Prompt string `json:"prompt"`
Model string `json:"model"` Rate string `json:"rate"`
Chaos int `json:"chaos"` Model string `json:"model"`
Raw bool `json:"raw"` Chaos int `json:"chaos"`
Seed int64 `json:"seed"` Raw bool `json:"raw"`
Stylize int `json:"stylize"` Seed int64 `json:"seed"`
Img string `json:"img"` Stylize int `json:"stylize"`
Weight float32 `json:"weight"` Img string `json:"img"`
Weight float32 `json:"weight"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
@ -268,6 +305,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
prompt += fmt.Sprintf(" --iw %f", data.Weight) prompt += fmt.Sprintf(" --iw %f", data.Weight)
} }
} }
if data.Raw {
prompt += " --style raw"
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += data.Model prompt += data.Model
} }
@ -287,12 +327,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
} }
h.mjService.PushTask(service.MjTask{ h.mjService.PushTask(service.MjTask{
Id: fmt.Sprintf("%d", job.Id), Id: int(job.Id),
Src: service.TaskSrcImg, SessionId: data.SessionId,
Type: service.Image, Src: service.TaskSrcImg,
Prompt: prompt, Type: service.Image,
UserId: userId, Prompt: prompt,
UserId: userId,
}) })
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -317,7 +368,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := data.SessionId jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src) src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg { if src == service.TaskSrcImg {
@ -330,14 +381,25 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.db.Create(&job); res.Error == nil { if res := h.db.Create(&job); res.Error == nil {
jobId = fmt.Sprintf("%d", job.Id) jobId = int(job.Id)
} else { } else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} }
h.mjService.PushTask(service.MjTask{ h.mjService.PushTask(service.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId,
Src: src, Src: src,
Type: service.Upscale, Type: service.Upscale,
Prompt: data.Prompt, Prompt: data.Prompt,
@ -358,6 +420,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
h.App.MjTaskClients.Put(data.SessionId, wsClient) h.App.MjTaskClients.Put(data.SessionId, wsClient)
} }
} }
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -370,7 +433,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := data.SessionId jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src) src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg { if src == service.TaskSrcImg {
@ -384,14 +447,25 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if res := h.db.Create(&job); res.Error == nil { if res := h.db.Create(&job); res.Error == nil {
jobId = fmt.Sprintf("%d", job.Id) jobId = int(job.Id)
} else { } else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error()) resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return return
} }
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} }
h.mjService.PushTask(service.MjTask{ h.mjService.PushTask(service.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId,
Src: src, Src: src,
Type: service.Variation, Type: service.Variation,
Prompt: data.Prompt, Prompt: data.Prompt,

View File

@ -22,14 +22,14 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("MJ 绘画参数:%+v", params) logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
f.service.PushTask(service.MjTask{ f.service.PushTask(service.MjTask{
Id: utils.InterfaceToString(params["session_id"]), SessionId: utils.InterfaceToString(params["session_id"]),
Src: service.TaskSrcChat, Src: service.TaskSrcChat,
Type: service.Image, Type: service.Image,
Prompt: prompt, Prompt: prompt,
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
Icon: utils.InterfaceToString(params["icon"]), Icon: utils.InterfaceToString(params["icon"]),
ChatId: utils.InterfaceToString(params["chat_id"]), ChatId: utils.InterfaceToString(params["chat_id"]),
}) })
return prompt, nil return prompt, nil
} }

View File

@ -41,7 +41,8 @@ const (
) )
type MjTask struct { type MjTask struct {
Id string `json:"id"` Id int `json:"id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"` Src TaskSrc `json:"src"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`

View File

@ -12,7 +12,7 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
// ReplyChunkMessage 回复客户片段端消息 // ReplyChunkMessage 回复客户片段端消息
func ReplyChunkMessage(client *types.WsClient, message types.WsMessage) { func ReplyChunkMessage(client *types.WsClient, message interface{}) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)
if err != nil { if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error()) logger.Errorf("Error for decoding json data: %v", err.Error())

View File

@ -185,31 +185,15 @@
display: flex; display: flex;
justify-content: center; justify-content: center;
align-items: center; align-items: center;
background-color: rgba(0,0,0,0.5);
} }
.page-mj .inner .task-list-box .running-job-list .job-item .job-item-inner .progress span { .page-mj .inner .task-list-box .running-job-list .job-item .job-item-inner .progress span {
font-size: 20px; font-size: 20px;
color: #fff; color: #fff;
} }
.page-mj .inner .task-list-box .running-job-list .job-item .el-image { .page-mj .inner .task-list-box .finish-job-list .job-item {
width: 100%; width: 100%;
height: 100%; height: 100%;
} }
.page-mj .inner .task-list-box .running-job-list .job-item .el-image .image-slot {
display: flex;
flex-flow: column;
justify-content: center;
align-items: center;
height: 100%;
color: #fff;
}
.page-mj .inner .task-list-box .running-job-list .job-item .el-image .image-slot .iconfont {
font-size: 50px;
margin-bottom: 10px;
}
.page-mj .inner .task-list-box .finish-job-list .job-item {
margin-bottom: 20px;
}
.page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line { .page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line {
margin: 6px 0; margin: 6px 0;
} }
@ -233,6 +217,37 @@
.page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a:hover { .page-mj .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a:hover {
background-color: #6d6f78; background-color: #6d6f78;
} }
.page-mj .inner .task-list-box .el-image {
width: 100%;
height: 100%;
max-height: 240px;
}
.page-mj .inner .task-list-box .el-image img {
height: 240px;
}
.page-mj .inner .task-list-box .el-image .el-image-viewer__wrapper img {
width: auto;
height: auto;
}
.page-mj .inner .task-list-box .el-image .image-slot {
display: flex;
flex-flow: column;
justify-content: center;
align-items: center;
height: 100%;
min-height: 200px;
color: #fff;
}
.page-mj .inner .task-list-box .el-image .image-slot .iconfont {
font-size: 50px;
margin-bottom: 10px;
}
.page-mj .inner .task-list-box .el-image.upscale {
max-height: 304px;
}
.page-mj .inner .task-list-box .el-image.upscale img {
height: 304px;
}
.mj-list-item-prompt .el-icon { .mj-list-item-prompt .el-icon {
margin-left: 10px; margin-left: 10px;
cursor: pointer; cursor: pointer;

View File

@ -247,7 +247,6 @@
.finish-job-list { .finish-job-list {
.job-item { .job-item {
margin-bottom 20px
width 100% width 100%
height 100% height 100%
@ -316,6 +315,14 @@
} }
} }
} }
.el-image.upscale {
max-height 304px
img {
height 304px
}
}
} }
} }

View File

@ -1,7 +1,7 @@
import axios from 'axios' import axios from 'axios'
import {getAdminToken, getSessionId, getUserToken} from "@/store/session"; import {getAdminToken, getSessionId, getUserToken} from "@/store/session";
axios.defaults.timeout = 10000 axios.defaults.timeout = 30000
axios.defaults.baseURL = process.env.VUE_APP_API_HOST axios.defaults.baseURL = process.env.VUE_APP_API_HOST
axios.defaults.withCredentials = true; axios.defaults.withCredentials = true;
axios.defaults.headers.post['Content-Type'] = 'application/json' axios.defaults.headers.post['Content-Type'] = 'application/json'

View File

@ -285,10 +285,10 @@
placement="top-start" placement="top-start"
title="提示词" title="提示词"
:width="240" :width="240"
trigger="hover" trigger="click"
> >
<template #reference> <template #reference>
<el-image :src="scope.item.img_url" <el-image :src="scope.item.img_url" :class="scope.item.type === 'upscale'?'upscale':''"
:zoom-rate="1.2" :zoom-rate="1.2"
:preview-src-list="previewImgList" :preview-src-list="previewImgList"
fit="cover" fit="cover"
@ -319,7 +319,7 @@
</template> </template>
</el-popover> </el-popover>
<div class="opt"> <div class="opt" v-if="scope.item.type !== 'upscale'">
<div class="opt-line"> <div class="opt-line">
<ul> <ul>
<li><a @click="upscale(1,scope.item)">U1</a></li> <li><a @click="upscale(1,scope.item)">U1</a></li>
@ -352,7 +352,7 @@
</template> </template>
<script setup> <script setup>
import {onMounted, ref} from "vue" import {nextTick, onMounted, ref} from "vue"
import {DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue"; import {DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue";
import Compressor from "compressorjs"; import Compressor from "compressorjs";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
@ -361,7 +361,9 @@ import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session"; import {getSessionId, getUserToken} from "@/store/session";
import {randString} from "@/utils/libs";
import hl from "highlight.js";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const mjBoxHeight = ref(window.innerHeight - 150)
@ -395,10 +397,85 @@ const runningJobs = ref([])
const finishedJobs = ref([]) const finishedJobs = ref([])
const previewImgList = ref([]) const previewImgList = ref([])
const router = useRouter() const router = useRouter()
const socket = ref(null)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
console.log(data)
let isNew = false
if (data.progress === 100) {
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
isNew = false
break
}
}
if (isNew) {
finishedJobs.value.unshift(data)
}
} else {
for (let i = 0; i < runningJobs.value; i++) {
if (runningJobs.value[i].id === data.id) {
isNew = false
runningJobs.value[i] = data
break
}
}
if (isNew) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
connect()
});
}
onMounted(() => { onMounted(() => {
checkSession().then(() => { checkSession().then(() => {
fetchFinishedJobs() //
fetchRunningJobs() httpGet("/api/mj/jobs?status=1").then(res => {
if (finishedJobs.value.length !== res.data.length) {
finishedJobs.value = res.data
}
previewImgList.value = []
for (let index in finishedJobs.value) {
previewImgList.value.push(finishedJobs.value[index]["img_url"])
}
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
//
httpGet("/api/mj/jobs?status=0").then(res => {
if (runningJobs.value.length !== res.data.length) {
runningJobs.value = res.data
}
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
}); });
@ -413,32 +490,6 @@ onMounted(() => {
}) })
}) })
const fetchFinishedJobs = () => {
httpGet("/api/mj/jobs?status=1").then(res => {
finishedJobs.value = res.data
previewImgList.value = []
for (let index in finishedJobs.value) {
previewImgList.value.push(finishedJobs.value[index]["img_url"])
}
setTimeout(() => {
fetchFinishedJobs()
}, 2000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
const fetchRunningJobs = () => {
httpGet("/api/mj/jobs?status=0").then(res => {
runningJobs.value = res.data
setTimeout(() => {
fetchRunningJobs()
}, 1000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
// //
const changeRate = (item) => { const changeRate = (item) => {
params.value.rate = item.value params.value.rate = item.value
@ -489,6 +540,10 @@ const generate = () => {
promptRef.value.focus() promptRef.value.focus()
return ElMessage.error("请输入绘画提示词!") return ElMessage.error("请输入绘画提示词!")
} }
if (params.value.model.indexOf("niji") !== -1 && params.value.raw) {
return ElMessage.error("动漫模型不允许启用原始模式")
}
params.value.session_id = getSessionId()
httpPost("/api/mj/image", params.value).then(() => { httpPost("/api/mj/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...") ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
}).catch(e => { }).catch(e => {