sd websocket refactor is finished

This commit is contained in:
RockYang 2024-09-27 18:28:54 +08:00
parent d95fab11be
commit 9edb3d0a82
7 changed files with 45 additions and 76 deletions

View File

@ -43,12 +43,14 @@ type MjTask struct {
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
Type TaskType `json:"type"` Type TaskType `json:"type"`
ClientId string `json:"client_id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Params SdTaskParams `json:"params"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
} }
type SdTaskParams struct { type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词

View File

@ -145,16 +145,12 @@ func (h *SdJobHandler) Image(c *gin.Context) {
h.sdService.PushTask(types.SdTask{ h.sdService.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
ClientId: data.ClientId,
Type: types.TaskImage, Type: types.TaskImage,
Params: params, Params: params,
UserId: userId, UserId: userId,
}) })
client := h.sdService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,

View File

@ -34,17 +34,17 @@ type Service struct {
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
leveldb *store.LevelDB leveldb *store.LevelDB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client wsService *service.WebsocketService
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C(), httpClient: req.C(),
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB, leveldb: levelDB,
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
} }
} }
@ -90,7 +90,7 @@ func (s *Service) Run() {
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
// 通知前端,任务失败 // 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue continue
} }
} }
@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据 // 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId) _ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据 // 保存预览图片数据
if resp.CurrentImage != "" { if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId)) client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChSd, message.Message)
if err != nil {
continue
}
} }
}() }()
} }

View File

@ -9,6 +9,7 @@ const (
type NotifyMessage struct { type NotifyMessage struct {
UserId int `json:"user_id"` UserId int `json:"user_id"`
ClientId string `json:"client_id"`
JobId int `json:"job_id"` JobId int `json:"job_id"`
Message string `json:"message"` Message string `json:"message"`
} }

View File

@ -71,12 +71,12 @@ const connect = () => {
handler.value = setInterval(() => { handler.value = setInterval(() => {
_socket.send(JSON.stringify({"type":"ping"})) _socket.send(JSON.stringify({"type":"ping"}))
},5000) },5000)
})
for (const key in store.messageHandlers) { for (const key in store.messageHandlers) {
console.log(key, store.messageHandlers[key]) console.log(key, store.messageHandlers[key])
store.setMessageHandler(store.messageHandlers[key]) store.setMessageHandler(store.messageHandlers[key])
} }
});
_socket.addEventListener('close', () => { _socket.addEventListener('close', () => {
store.setSocket(null) store.setSocket(null)

View File

@ -487,11 +487,11 @@
<script setup> <script setup>
import {nextTick, onMounted, onUnmounted, ref} from "vue" import {nextTick, onMounted, onUnmounted, ref} from "vue"
import {Delete, DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue"; import {Delete, DocumentCopy, InfoFilled, Orange} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox, ElNotification} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession, getSystemInfo} from "@/store/cache"; import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session"; import {getSessionId} from "@/store/session";
import {useSharedStore} from "@/store/sharedata"; import {useSharedStore} from "@/store/sharedata";
@ -520,6 +520,7 @@ const samplers = ["Euler a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM++ 2M SD
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"] const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"] const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
const params = ref({ const params = ref({
client_id: getClientId(),
width: 1024, width: 1024,
height: 1024, height: 1024,
sampler: samplers[0], sampler: samplers[0],
@ -547,46 +548,7 @@ if (_params) {
const power = ref(0) const power = ref(0)
const sdPower = ref(0) // SD const sdPower = ref(0) // SD
const socket = ref(null)
const userId = ref(0) const userId = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
})
}
const clipboard = ref(null) const clipboard = ref(null)
onMounted(() => { onMounted(() => {
initData() initData()
@ -605,14 +567,25 @@ onMounted(() => {
}).catch(e => { }).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message) ElMessage.error("获取系统配置失败:" + e.message)
}) })
store.addMessageHandler("sd",(data) => {
//
if (data.channel !== "sd" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
})
}) })
onUnmounted(() => { onUnmounted(() => {
clipboard.value.destroy() clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
}) })
@ -624,7 +597,6 @@ const initData = () => {
page.value = 0 page.value = 0
fetchRunningJobs() fetchRunningJobs()
fetchFinishJobs() fetchFinishJobs()
connect()
}).catch(() => { }).catch(() => {
}); });
} }
@ -694,6 +666,7 @@ const generate = () => {
httpPost("/api/sd/image", params.value).then(() => { httpPost("/api/sd/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...") ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
power.value -= sdPower.value power.value -= sdPower.value
fetchRunningJobs()
}).catch(e => { }).catch(e => {
ElMessage.error("任务推送失败:" + e.message) ElMessage.error("任务推送失败:" + e.message)
}) })

View File

@ -194,7 +194,7 @@ const fetchData = () => {
const add = function () { const add = function () {
showDialog.value = true showDialog.value = true
title.value = "新增 API KEY" title.value = "新增 API KEY"
item.value = {enabled: true,api_url: "https://api.chat-plus.net"} item.value = {enabled: true,api_url: "https://api.geekai.pro"}
} }
const edit = function (row) { const edit = function (row) {