feat: replace http polling with webscoket notify in sd image page

This commit is contained in:
RockYang 2024-02-26 15:45:54 +08:00
parent 668d4c9c64
commit b95dff0751
8 changed files with 189 additions and 124 deletions

View File

@ -155,6 +155,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" || c.Request.URL.Path == "/api/sd/jobs" ||
c.Request.URL.Path == "/api/sd/client" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||

View File

@ -156,6 +156,11 @@ func (h *SdJobHandler) Image(c *gin.Context) {
UserId: userId, UserId: userId,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls // update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
@ -229,6 +234,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
func (h *SdJobHandler) Remove(c *gin.Context) { func (h *SdJobHandler) Remove(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@ -249,6 +255,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@ -247,6 +247,7 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.POST("remove", h.Remove) group.POST("remove", h.Remove)

View File

@ -52,6 +52,26 @@ func (p *ServicePool) PushTask(task types.SdTask) {
p.taskQueue.RPush(task) p.taskQueue.RPush(task)
} }
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
// HasAvailableService check if it has available mj service in pool // HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool { func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0 return len(p.services) > 0

View File

@ -310,4 +310,7 @@ func (s *Service) callback(data CBReq) {
// restore img_calls // restore img_calls
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", data.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
} }
// 发送更新状态信号
s.notifyQueue.RPush(data.UserId)
} }

View File

@ -216,6 +216,7 @@
} }
} }
.remove { .remove {
display none display none
position absolute position absolute

View File

@ -307,6 +307,7 @@
</div> </div>
<div class="task-list-box"> <div class="task-list-box">
<div class="task-list-inner" :style="{ height: listBoxHeight + 'px' }"> <div class="task-list-inner" :style="{ height: listBoxHeight + 'px' }">
<div class="job-list-box">
<h2>任务列表</h2> <h2>任务列表</h2>
<div class="running-job-list"> <div class="running-job-list">
<ItemList :items="runningJobs" v-if="runningJobs.length > 0" :width="240"> <ItemList :items="runningJobs" v-if="runningJobs.length > 0" :width="240">
@ -374,7 +375,8 @@
<div class="remove"> <div class="remove">
<el-button type="danger" :icon="Delete" @click="removeImage($event,scope.item)" circle/> <el-button type="danger" :icon="Delete" @click="removeImage($event,scope.item)" circle/>
<el-button type="warning" v-if="scope.item.publish" @click="publishImage($event,scope.item, false)" <el-button type="warning" v-if="scope.item.publish"
@click="publishImage($event,scope.item, false)"
circle> circle>
<i class="iconfont icon-cancel-share"></i> <i class="iconfont icon-cancel-share"></i>
</el-button> </el-button>
@ -388,6 +390,7 @@
<el-empty :image-size="100" v-else/> <el-empty :image-size="100" v-else/>
</div> <!-- end finish job list--> </div> <!-- end finish job list-->
</div> </div>
</div>
</div><!-- end task list box --> </div><!-- end task list box -->
</div> </div>
@ -529,7 +532,7 @@ window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40 listBoxHeight.value = window.innerHeight - 40
mjBoxHeight.value = window.innerHeight - 150 mjBoxHeight.value = window.innerHeight - 150
} }
const samplers = ["Euler a", "Euler", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras"] const samplers = ["Euler a", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras"]
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"] const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
const params = ref({ const params = ref({
width: 1024, width: 1024,
@ -580,18 +583,73 @@ const translatePrompt = () => {
}) })
} }
const socket = ref(null)
const userId = ref(0)
const heartbeatHandle = ref(null)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
//
const sendHeartbeat = () => {
clearTimeout(heartbeatHandle.value)
new Promise((resolve, reject) => {
if (socket.value !== null) {
socket.value.send(JSON.stringify({type: "heartbeat", content: "ping"}))
}
resolve("success")
}).then(() => {
heartbeatHandle.value = setTimeout(() => sendHeartbeat(), 5000)
});
}
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
//
sendHeartbeat()
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
fetchRunningJobs(userId.value)
fetchFinishJobs(userId.value)
}
});
_socket.addEventListener('close', () => {
connect()
});
}
onMounted(() => { onMounted(() => {
checkSession().then(user => { checkSession().then(user => {
imgCalls.value = user['img_calls'] imgCalls.value = user['img_calls']
userId.value = user.id
fetchRunningJobs(user.id) fetchRunningJobs(user.id)
fetchFinishJobs(user.id) fetchFinishJobs(user.id)
connect()
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
}); });
const clipboard = new Clipboard('.copy-prompt');
clipboard.on('success', () => {
ElMessage.success("复制成功!");
})
const fetchRunningJobs = (userId) => { clipboard.on('error', () => {
ElMessage.error('复制失败!');
})
})
const fetchRunningJobs = (userId) => {
// //
httpGet(`/api/sd/jobs?status=0&user_id=${userId}`).then(res => { httpGet(`/api/sd/jobs?status=0&user_id=${userId}`).then(res => {
const jobs = res.data const jobs = res.data
@ -610,51 +668,19 @@ onMounted(() => {
_jobs.push(jobs[i]) _jobs.push(jobs[i])
} }
runningJobs.value = _jobs runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 1000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
setTimeout(() => fetchRunningJobs(userId), 5000)
}) })
} }
// //
const fetchFinishJobs = (userId) => { const fetchFinishJobs = (userId) => {
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => { httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
if (finishedJobs.value.length === 0 || res.data.length > finishedJobs.value.length) {
finishedJobs.value = res.data finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 1000)
return
}
// check if the img url is changed
const list = res.data
let changed = false
for (let i = 0; i < list.length; i++) {
if (list[i]["img_url"] !== finishedJobs.value[i]["img_url"]) {
changed = true
break
}
}
if (changed) {
finishedJobs.value = list
}
setTimeout(() => fetchFinishJobs(userId), 1000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
setTimeout(() => fetchFinishJobs(userId), 5000)
}) })
} }
const clipboard = new Clipboard('.copy-prompt');
clipboard.on('success', () => {
ElMessage.success("复制成功!");
})
clipboard.on('error', () => {
ElMessage.error('复制失败!');
})
})
// //
@ -697,7 +723,7 @@ const removeImage = (event, item) => {
type: 'warning', type: 'warning',
} }
).then(() => { ).then(() => {
httpPost("/api/sd/remove", {id: item.id, img_url: item.img_url}).then(() => { httpPost("/api/sd/remove", {id: item.id, img_url: item.img_url, user_id: userId.value}).then(() => {
ElMessage.success("任务删除成功") ElMessage.success("任务删除成功")
}).catch(e => { }).catch(e => {
ElMessage.error("任务删除失败:" + e.message) ElMessage.error("任务删除失败:" + e.message)

View File

@ -7,6 +7,8 @@
@keyup="searchChat($event)"></el-input> @keyup="searchChat($event)"></el-input>
<el-input v-model="data.chat.query.title" placeholder="对话标题" class="handle-input mr10" <el-input v-model="data.chat.query.title" placeholder="对话标题" class="handle-input mr10"
@keyup="searchChat($event)"></el-input> @keyup="searchChat($event)"></el-input>
<el-input v-model="data.chat.query.model" placeholder="模型" class="handle-input mr10"
@keyup="searchChat($event)"></el-input>
<el-date-picker <el-date-picker
v-model="data.chat.query.created_at" v-model="data.chat.query.created_at"
type="daterange" type="daterange"
@ -97,7 +99,7 @@
</template> </template>
</el-table-column> </el-table-column>
<el-table-column prop="token" label="算力"/> <el-table-column prop="token" label="消耗算力"/>
<el-table-column label="创建时间"> <el-table-column label="创建时间">
<template #default="scope"> <template #default="scope">