mj websocket refactor is ready

This commit is contained in:
RockYang 2024-09-29 07:51:08 +08:00
parent 8fffa60569
commit 00a8bc6784
7 changed files with 63 additions and 149 deletions

View File

@ -24,8 +24,9 @@ const (
// MjTask MidJourney 任务
type MjTask struct {
Id uint `json:"id"`
TaskId string `json:"task_id"`
Id uint `json:"id"` // 任务ID
TaskId string `json:"task_id"` // 中转任务ID
ClientId string `json:"client_id"`
ImgArr []string `json:"img_arr"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`

View File

@ -8,7 +8,6 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
@ -67,6 +66,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
TaskType string `json:"task_type"`
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
@ -177,6 +177,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
ClientId: data.ClientId,
TaskId: taskId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
@ -187,11 +188,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
@ -208,6 +204,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
type reqVo struct {
Index int `json:"index"`
ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
@ -244,6 +241,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
ClientId: data.ClientId,
Type: types.TaskUpscale,
UserId: userId,
ChannelId: data.ChannelId,
@ -253,11 +251,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
@ -305,6 +298,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskVariation,
ClientId: data.ClientId,
UserId: userId,
Index: data.Index,
ChannelId: data.ChannelId,
@ -313,11 +307,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
@ -397,14 +386,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
if err != nil {
continue
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
jobs = append(jobs, job)
}
return nil, vo.NewPage(total, page, pageSize, jobs)
@ -449,11 +430,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err)
}
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}

View File

@ -28,18 +28,20 @@ type Service struct {
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
wsService *service.WebsocketService
uploaderManager *oss.UploaderManager
clientIds map[uint]string
}
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service {
return &Service{
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
client: client,
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploaderManager: manager,
clientIds: map[uint]string{},
}
}
@ -77,6 +79,7 @@ func (s *Service) Run() {
if task.Mode == "" {
task.Mode = "fast"
}
s.clientIds[task.Id] = task.ClientId
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
@ -119,7 +122,7 @@ func (s *Service) Run() {
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
@ -166,14 +169,11 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
cli := s.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChMj, message.Message)
}
}()
}
@ -211,14 +211,11 @@ func (s *Service) DownloadImages() {
v.ImgURL = imgURL
s.db.Updates(&v)
cli := s.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(service.TaskStatusFinished))
if err != nil {
continue
}
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[v.Id],
UserId: v.UserId,
JobId: int(v.Id),
Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 5)
@ -264,7 +261,11 @@ func (s *Service) SyncTaskProgress() {
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed})
continue
}
@ -289,7 +290,11 @@ func (s *Service) SyncTaskProgress() {
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: message})
}
}

View File

@ -6,6 +6,6 @@ VUE_APP_ADMIN_USER=admin
VUE_APP_ADMIN_PASS=admin123
VUE_APP_KEY_PREFIX=GeekAI_DEV_
VUE_APP_TITLE="Geek-AI 创作系统"
VUE_APP_VERSION=v4.1.4
VUE_APP_VERSION=v4.1.5
VUE_APP_DOCS_URL=https://docs.geekai.me
VUE_APP_GIT_URL=https://github.com/yangjian102621/geekai

View File

@ -1,28 +1,14 @@
<template>
<div class="running-job-list">
<div class="running-job-box" v-if="list.length > 0">
<div class="job-item" v-for="item in list">
<div class="job-item" v-for="item in list" :key="item.id">
<div v-if="item.progress > 0" class="job-item-inner">
<el-image v-if="item.img_url" :src="item['img_url']" fit="cover" loading="lazy">
<template #placeholder>
<div class="image-slot">
正在加载图片
</div>
</template>
<template #error>
<div class="image-slot">
<el-icon>
<Picture/>
</el-icon>
</div>
</template>
</el-image>
<div class="progress" v-if="item.progress > 0">
<el-progress type="circle" :percentage="item.progress" :width="100"
color="#47fff1"/>
</div>
<div class="progress">
<el-progress type="circle" :percentage="item.progress" :width="100"
color="#47fff1"/>
</div>
</div>
<el-image fit="cover" v-else>
<template #error>
@ -39,11 +25,7 @@
</template>
<script setup>
import {ref} from "vue";
import {CircleCloseFilled, Picture} from "@element-plus/icons-vue";
import {isImage, removeArrayItem, substr} from "@/utils/libs";
import {FormatFileSize, GetFileIcon, GetFileType} from "@/store/system";
// eslint-disable-next-line no-undef
const props = defineProps({
list: {
type: Array,

View File

@ -357,7 +357,6 @@ onMounted(() => {
window.onresize = () => resizeElement();
store.addMessageHandler("chat", (data) => {
console.log(data)
//
if (data.channel !== 'chat' || data.clientId !== getClientId()) {
return
@ -472,8 +471,6 @@ const initData = () => {
for (let item of items) {
if (item.kind === 'file') {
const file = item.getAsFile();
fileFound = true;
const formData = new FormData();
formData.append('file', file);
loading.value = true
@ -490,10 +487,6 @@ const initData = () => {
break;
}
}
if (!fileFound) {
document.getElementById('status').innerText = 'No file found in paste data.';
}
});
}

View File

@ -215,7 +215,7 @@
<div class="param-line">
<div class="img-inline">
<div class="img-list-box">
<div class="img-item" v-for="imgURL in imgList">
<div class="img-item" v-for="imgURL in imgList" :key="imgURL">
<el-image :src="imgURL" fit="cover"/>
<el-button type="danger" :icon="Delete" @click="removeUploadImage(imgURL)" circle/>
</div>
@ -293,7 +293,7 @@
<div class="text">请上传两张以上的图片最多不超过五张超过五张图片请使用图生图功能</div>
<div class="img-inline">
<div class="img-list-box">
<div class="img-item" v-for="imgURL in imgList">
<div class="img-item" v-for="imgURL in imgList" :key="imgURL">
<el-image :src="imgURL" fit="cover"/>
<el-button type="danger" :icon="Delete" @click="removeUploadImage(imgURL)" circle/>
</div>
@ -312,7 +312,7 @@
<div class="text">请上传两张有脸部的图片用左边图片的脸替换右边图片的脸</div>
<div class="img-inline">
<div class="img-list-box">
<div class="img-item" v-for="imgURL in imgList">
<div class="img-item" v-for="imgURL in imgList" :key="imgURL">
<el-image :src="imgURL" fit="cover"/>
<el-button type="danger" :icon="Delete" @click="removeUploadImage(imgURL)" circle/>
</div>
@ -602,13 +602,13 @@
</template>
<script setup>
import {onMounted, onUnmounted, ref} from "vue"
import {nextTick, onMounted, onUnmounted, ref} from "vue"
import {ChromeFilled, Delete, DocumentCopy, InfoFilled, Picture, Plus, UploadFilled} from "@element-plus/icons-vue";
import Compressor from "compressorjs";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox, ElNotification} from "element-plus";
import Clipboard from "clipboard";
import {checkSession, getSystemInfo} from "@/store/cache";
import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session";
import {copyObj, removeArrayItem} from "@/utils/libs";
@ -678,6 +678,7 @@ const options = [
const router = useRouter()
const initParams = {
client_id: getClientId(),
task_type: "image",
rate: rates[0].value,
model: models[0].value,
@ -704,66 +705,10 @@ const activeName = ref('txt2img')
const runningJobs = ref([])
const finishedJobs = ref([])
const socket = ref(null)
const power = ref(0)
const userId = ref(0)
const isLogin = ref(false)
const heartbeatHandle = ref(null)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
//
const sendHeartbeat = () => {
clearTimeout(heartbeatHandle.value)
new Promise((resolve, reject) => {
if (socket.value !== null) {
socket.value.send(JSON.stringify({type: "heartbeat", content: "ping"}))
}
resolve("success")
}).then(() => {
heartbeatHandle.value = setTimeout(() => sendHeartbeat(), 5000)
});
}
const _socket = new WebSocket(host + `/api/mj/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
//
sendHeartbeat()
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs(page.value)
}
fetchRunningJobs()
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
});
}
const clipboard = ref(null)
onMounted(() => {
initData()
@ -775,14 +720,24 @@ onMounted(() => {
clipboard.value.on('error', () => {
ElMessage.error('复制失败!');
})
store.addMessageHandler("mj",(data) => {
//
if (data.channel !== "mj" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
})
})
onUnmounted(() => {
clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
})
//
@ -794,7 +749,6 @@ const initData = () => {
page.value = 0
fetchRunningJobs()
fetchFinishJobs()
connect()
}).catch(() => {
});
@ -952,6 +906,7 @@ const generate = () => {
httpPost("/api/mj/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
power.value -= mjPower.value
fetchRunningJobs()
}).catch(e => {
ElMessage.error("任务推送失败:" + e.message)
})
@ -970,6 +925,7 @@ const variation = (index, item) => {
const send = (url, index, item) => {
httpPost(url, {
index: index,
client_id: getClientId(),
channel_id: item.channel_id,
message_id: item.message_id,
message_hash: item.hash,
@ -978,6 +934,7 @@ const send = (url, index, item) => {
}).then(() => {
ElMessage.success("任务推送成功,请耐心等待任务执行...")
power.value -= mjActionPower.value
fetchRunningJobs()
}).catch(e => {
ElMessage.error("任务推送失败:" + e.message)
})