sd websocket refactor is finished

This commit is contained in:
RockYang 2024-09-27 18:28:54 +08:00
parent 2debe7e927
commit 8fffa60569
7 changed files with 45 additions and 76 deletions

View File

@ -43,12 +43,14 @@ type MjTask struct {
type SdTask struct {
Id int `json:"id"` // job 数据库ID
Type TaskType `json:"type"`
ClientId string `json:"client_id"`
UserId int `json:"user_id"`
Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"`
}
type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词

View File

@ -144,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) {
}
h.sdService.PushTask(types.SdTask{
Id: int(job.Id),
Type: types.TaskImage,
Params: params,
UserId: userId,
Id: int(job.Id),
ClientId: data.ClientId,
Type: types.TaskImage,
Params: params,
UserId: userId,
})
client := h.sdService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,

View File

@ -34,17 +34,17 @@ type Service struct {
db *gorm.DB
uploadManager *oss.UploaderManager
leveldb *store.LevelDB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
wsService *service.WebsocketService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C(),
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db,
leveldb: levelDB,
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploadManager: manager,
}
}
@ -90,7 +90,7 @@ func (s *Service) Run() {
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
}
}
@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
utils.SendChannelMsg(client, types.ChSd, message.Message)
}
}()
}

View File

@ -8,9 +8,10 @@ const (
)
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
UserId int `json:"user_id"`
ClientId string `json:"client_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"

View File

@ -71,12 +71,12 @@ const connect = () => {
handler.value = setInterval(() => {
_socket.send(JSON.stringify({"type":"ping"}))
},5000)
})
for (const key in store.messageHandlers) {
console.log(key, store.messageHandlers[key])
store.setMessageHandler(store.messageHandlers[key])
}
});
for (const key in store.messageHandlers) {
console.log(key, store.messageHandlers[key])
store.setMessageHandler(store.messageHandlers[key])
}
_socket.addEventListener('close', () => {
store.setSocket(null)

View File

@ -487,11 +487,11 @@
<script setup>
import {nextTick, onMounted, onUnmounted, ref} from "vue"
import {Delete, DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue";
import {Delete, DocumentCopy, InfoFilled, Orange} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox, ElNotification} from "element-plus";
import {ElMessage, ElMessageBox} from "element-plus";
import Clipboard from "clipboard";
import {checkSession, getSystemInfo} from "@/store/cache";
import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session";
import {useSharedStore} from "@/store/sharedata";
@ -520,6 +520,7 @@ const samplers = ["Euler a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM++ 2M SD
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
const params = ref({
client_id: getClientId(),
width: 1024,
height: 1024,
sampler: samplers[0],
@ -547,46 +548,7 @@ if (_params) {
const power = ref(0)
const sdPower = ref(0) // SD
const socket = ref(null)
const userId = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
})
}
const clipboard = ref(null)
onMounted(() => {
initData()
@ -605,14 +567,25 @@ onMounted(() => {
}).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message)
})
store.addMessageHandler("sd",(data) => {
//
if (data.channel !== "sd" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
})
})
onUnmounted(() => {
clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
})
@ -624,7 +597,6 @@ const initData = () => {
page.value = 0
fetchRunningJobs()
fetchFinishJobs()
connect()
}).catch(() => {
});
}
@ -694,6 +666,7 @@ const generate = () => {
httpPost("/api/sd/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
power.value -= sdPower.value
fetchRunningJobs()
}).catch(e => {
ElMessage.error("任务推送失败:" + e.message)
})

View File

@ -194,7 +194,7 @@ const fetchData = () => {
const add = function () {
showDialog.value = true
title.value = "新增 API KEY"
item.value = {enabled: true,api_url: "https://api.chat-plus.net"}
item.value = {enabled: true,api_url: "https://api.geekai.pro"}
}
const edit = function (row) {