sd websocket refactor is finished

This commit is contained in:
RockYang 2024-09-27 18:28:54 +08:00
parent 2debe7e927
commit 8fffa60569
7 changed files with 45 additions and 76 deletions

View File

@ -43,12 +43,14 @@ type MjTask struct {
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
Type TaskType `json:"type"` Type TaskType `json:"type"`
ClientId string `json:"client_id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Params SdTaskParams `json:"params"` Params SdTaskParams `json:"params"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
} }
type SdTaskParams struct { type SdTaskParams struct {
ClientId string `json:"client_id"` // 客户端ID
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词 Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词 NegPrompt string `json:"neg_prompt"` // 反向提示词

View File

@ -144,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) {
} }
h.sdService.PushTask(types.SdTask{ h.sdService.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
Type: types.TaskImage, ClientId: data.ClientId,
Params: params, Type: types.TaskImage,
UserId: userId, Params: params,
UserId: userId,
}) })
client := h.sdService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power // update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume, Type: types.PowerConsume,

View File

@ -34,17 +34,17 @@ type Service struct {
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
leveldb *store.LevelDB leveldb *store.LevelDB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client wsService *service.WebsocketService
} }
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service { func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
return &Service{ return &Service{
httpClient: req.C(), httpClient: req.C(),
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli), taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli), notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db, db: db,
leveldb: levelDB, leveldb: levelDB,
Clients: types.NewLMap[uint, *types.WsClient](), wsService: wsService,
uploadManager: manager, uploadManager: manager,
} }
} }
@ -90,7 +90,7 @@ func (s *Service) Run() {
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
// 通知前端,任务失败 // 通知前端,任务失败
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue continue
} }
} }
@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据 // 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId) _ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据 // 保存预览图片数据
if resp.CurrentImage != "" { if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() {
if err != nil { if err != nil {
continue continue
} }
client := s.Clients.Get(uint(message.UserId)) client := s.wsService.Clients.Get(message.ClientId)
if client == nil { if client == nil {
continue continue
} }
err = client.Send([]byte(message.Message)) utils.SendChannelMsg(client, types.ChSd, message.Message)
if err != nil {
continue
}
} }
}() }()
} }

View File

@ -8,9 +8,10 @@ const (
) )
type NotifyMessage struct { type NotifyMessage struct {
UserId int `json:"user_id"` UserId int `json:"user_id"`
JobId int `json:"job_id"` ClientId string `json:"client_id"`
Message string `json:"message"` JobId int `json:"job_id"`
Message string `json:"message"`
} }
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"

View File

@ -71,12 +71,12 @@ const connect = () => {
handler.value = setInterval(() => { handler.value = setInterval(() => {
_socket.send(JSON.stringify({"type":"ping"})) _socket.send(JSON.stringify({"type":"ping"}))
},5000) },5000)
})
for (const key in store.messageHandlers) { for (const key in store.messageHandlers) {
console.log(key, store.messageHandlers[key]) console.log(key, store.messageHandlers[key])
store.setMessageHandler(store.messageHandlers[key]) store.setMessageHandler(store.messageHandlers[key])
} }
});
_socket.addEventListener('close', () => { _socket.addEventListener('close', () => {
store.setSocket(null) store.setSocket(null)

View File

@ -487,11 +487,11 @@
<script setup> <script setup>
import {nextTick, onMounted, onUnmounted, ref} from "vue" import {nextTick, onMounted, onUnmounted, ref} from "vue"
import {Delete, DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue"; import {Delete, DocumentCopy, InfoFilled, Orange} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox, ElNotification} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession, getSystemInfo} from "@/store/cache"; import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session"; import {getSessionId} from "@/store/session";
import {useSharedStore} from "@/store/sharedata"; import {useSharedStore} from "@/store/sharedata";
@ -520,6 +520,7 @@ const samplers = ["Euler a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM++ 2M SD
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"] const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"] const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
const params = ref({ const params = ref({
client_id: getClientId(),
width: 1024, width: 1024,
height: 1024, height: 1024,
sampler: samplers[0], sampler: samplers[0],
@ -547,46 +548,7 @@ if (_params) {
const power = ref(0) const power = ref(0)
const sdPower = ref(0) // SD const sdPower = ref(0) // SD
const socket = ref(null)
const userId = ref(0) const userId = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8")
reader.onload = () => {
const message = String(reader.result)
if (message === "FINISH" || message === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
}
}
});
_socket.addEventListener('close', () => {
if (socket.value !== null) {
connect()
}
})
}
const clipboard = ref(null) const clipboard = ref(null)
onMounted(() => { onMounted(() => {
initData() initData()
@ -605,14 +567,25 @@ onMounted(() => {
}).catch(e => { }).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message) ElMessage.error("获取系统配置失败:" + e.message)
}) })
store.addMessageHandler("sd",(data) => {
//
if (data.channel !== "sd" || data.clientId !== getClientId()) {
return
}
if (data.body === "FINISH" || data.body === "FAIL") {
page.value = 0
isOver.value = false
fetchFinishJobs()
}
nextTick(() => fetchRunningJobs())
})
}) })
onUnmounted(() => { onUnmounted(() => {
clipboard.value.destroy() clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
}) })
@ -624,7 +597,6 @@ const initData = () => {
page.value = 0 page.value = 0
fetchRunningJobs() fetchRunningJobs()
fetchFinishJobs() fetchFinishJobs()
connect()
}).catch(() => { }).catch(() => {
}); });
} }
@ -694,6 +666,7 @@ const generate = () => {
httpPost("/api/sd/image", params.value).then(() => { httpPost("/api/sd/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...") ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
power.value -= sdPower.value power.value -= sdPower.value
fetchRunningJobs()
}).catch(e => { }).catch(e => {
ElMessage.error("任务推送失败:" + e.message) ElMessage.error("任务推送失败:" + e.message)
}) })

View File

@ -194,7 +194,7 @@ const fetchData = () => {
const add = function () { const add = function () {
showDialog.value = true showDialog.value = true
title.value = "新增 API KEY" title.value = "新增 API KEY"
item.value = {enabled: true,api_url: "https://api.chat-plus.net"} item.value = {enabled: true,api_url: "https://api.geekai.pro"}
} }
const edit = function (row) { const edit = function (row) {