mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 17:26:38 +08:00
sd websocket refactor is finished
This commit is contained in:
parent
2debe7e927
commit
8fffa60569
@ -43,12 +43,14 @@ type MjTask struct {
|
|||||||
type SdTask struct {
|
type SdTask struct {
|
||||||
Id int `json:"id"` // job 数据库ID
|
Id int `json:"id"` // job 数据库ID
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
|
ClientId string `json:"client_id"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Params SdTaskParams `json:"params"`
|
Params SdTaskParams `json:"params"`
|
||||||
RetryCount int `json:"retry_count"`
|
RetryCount int `json:"retry_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SdTaskParams struct {
|
type SdTaskParams struct {
|
||||||
|
ClientId string `json:"client_id"` // 客户端ID
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
Prompt string `json:"prompt"` // 提示词
|
Prompt string `json:"prompt"` // 提示词
|
||||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||||
|
@ -144,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.sdService.PushTask(types.SdTask{
|
h.sdService.PushTask(types.SdTask{
|
||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
Type: types.TaskImage,
|
ClientId: data.ClientId,
|
||||||
Params: params,
|
Type: types.TaskImage,
|
||||||
UserId: userId,
|
Params: params,
|
||||||
|
UserId: userId,
|
||||||
})
|
})
|
||||||
|
|
||||||
client := h.sdService.Clients.Get(uint(job.UserId))
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// update user's power
|
// update user's power
|
||||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||||
Type: types.PowerConsume,
|
Type: types.PowerConsume,
|
||||||
|
@ -34,17 +34,17 @@ type Service struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
wsService *service.WebsocketService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
httpClient: req.C(),
|
httpClient: req.C(),
|
||||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||||
db: db,
|
db: db,
|
||||||
leveldb: levelDB,
|
leveldb: levelDB,
|
||||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
wsService: wsService,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -90,7 +90,7 @@ func (s *Service) Run() {
|
|||||||
"err_msg": err.Error(),
|
"err_msg": err.Error(),
|
||||||
})
|
})
|
||||||
// 通知前端,任务失败
|
// 通知前端,任务失败
|
||||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
|
|
||||||
// task finished
|
// task finished
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||||
// 从 leveldb 中删除预览图片数据
|
// 从 leveldb 中删除预览图片数据
|
||||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||||
return nil
|
return nil
|
||||||
@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
if err == nil && resp.Progress > 0 {
|
if err == nil && resp.Progress > 0 {
|
||||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||||
// 发送更新状态信号
|
// 发送更新状态信号
|
||||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||||
// 保存预览图片数据
|
// 保存预览图片数据
|
||||||
if resp.CurrentImage != "" {
|
if resp.CurrentImage != "" {
|
||||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||||
@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := s.Clients.Get(uint(message.UserId))
|
client := s.wsService.Clients.Get(message.ClientId)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte(message.Message))
|
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -8,9 +8,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type NotifyMessage struct {
|
type NotifyMessage struct {
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
JobId int `json:"job_id"`
|
ClientId string `json:"client_id"`
|
||||||
Message string `json:"message"`
|
JobId int `json:"job_id"`
|
||||||
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
|
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
|
||||||
|
@ -71,12 +71,12 @@ const connect = () => {
|
|||||||
handler.value = setInterval(() => {
|
handler.value = setInterval(() => {
|
||||||
_socket.send(JSON.stringify({"type":"ping"}))
|
_socket.send(JSON.stringify({"type":"ping"}))
|
||||||
},5000)
|
},5000)
|
||||||
|
})
|
||||||
|
|
||||||
for (const key in store.messageHandlers) {
|
for (const key in store.messageHandlers) {
|
||||||
console.log(key, store.messageHandlers[key])
|
console.log(key, store.messageHandlers[key])
|
||||||
store.setMessageHandler(store.messageHandlers[key])
|
store.setMessageHandler(store.messageHandlers[key])
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
_socket.addEventListener('close', () => {
|
_socket.addEventListener('close', () => {
|
||||||
store.setSocket(null)
|
store.setSocket(null)
|
||||||
|
@ -487,11 +487,11 @@
|
|||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import {nextTick, onMounted, onUnmounted, ref} from "vue"
|
import {nextTick, onMounted, onUnmounted, ref} from "vue"
|
||||||
import {Delete, DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue";
|
import {Delete, DocumentCopy, InfoFilled, Orange} from "@element-plus/icons-vue";
|
||||||
import {httpGet, httpPost} from "@/utils/http";
|
import {httpGet, httpPost} from "@/utils/http";
|
||||||
import {ElMessage, ElMessageBox, ElNotification} from "element-plus";
|
import {ElMessage, ElMessageBox} from "element-plus";
|
||||||
import Clipboard from "clipboard";
|
import Clipboard from "clipboard";
|
||||||
import {checkSession, getSystemInfo} from "@/store/cache";
|
import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
|
||||||
import {useRouter} from "vue-router";
|
import {useRouter} from "vue-router";
|
||||||
import {getSessionId} from "@/store/session";
|
import {getSessionId} from "@/store/session";
|
||||||
import {useSharedStore} from "@/store/sharedata";
|
import {useSharedStore} from "@/store/sharedata";
|
||||||
@ -520,6 +520,7 @@ const samplers = ["Euler a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM++ 2M SD
|
|||||||
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]
|
const schedulers = ["Automatic", "Karras", "Exponential", "Uniform"]
|
||||||
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
|
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
|
||||||
const params = ref({
|
const params = ref({
|
||||||
|
client_id: getClientId(),
|
||||||
width: 1024,
|
width: 1024,
|
||||||
height: 1024,
|
height: 1024,
|
||||||
sampler: samplers[0],
|
sampler: samplers[0],
|
||||||
@ -547,46 +548,7 @@ if (_params) {
|
|||||||
const power = ref(0)
|
const power = ref(0)
|
||||||
const sdPower = ref(0) // 画一张 SD 图片消耗算力
|
const sdPower = ref(0) // 画一张 SD 图片消耗算力
|
||||||
|
|
||||||
const socket = ref(null)
|
|
||||||
const userId = ref(0)
|
const userId = ref(0)
|
||||||
const connect = () => {
|
|
||||||
let host = process.env.VUE_APP_WS_HOST
|
|
||||||
if (host === '') {
|
|
||||||
if (location.protocol === 'https:') {
|
|
||||||
host = 'wss://' + location.host;
|
|
||||||
} else {
|
|
||||||
host = 'ws://' + location.host;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`);
|
|
||||||
_socket.addEventListener('open', () => {
|
|
||||||
socket.value = _socket;
|
|
||||||
});
|
|
||||||
|
|
||||||
_socket.addEventListener('message', event => {
|
|
||||||
if (event.data instanceof Blob) {
|
|
||||||
const reader = new FileReader();
|
|
||||||
reader.readAsText(event.data, "UTF-8")
|
|
||||||
reader.onload = () => {
|
|
||||||
const message = String(reader.result)
|
|
||||||
if (message === "FINISH" || message === "FAIL") {
|
|
||||||
page.value = 0
|
|
||||||
isOver.value = false
|
|
||||||
fetchFinishJobs()
|
|
||||||
}
|
|
||||||
nextTick(() => fetchRunningJobs())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
_socket.addEventListener('close', () => {
|
|
||||||
if (socket.value !== null) {
|
|
||||||
connect()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const clipboard = ref(null)
|
const clipboard = ref(null)
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
initData()
|
initData()
|
||||||
@ -605,14 +567,25 @@ onMounted(() => {
|
|||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("获取系统配置失败:" + e.message)
|
ElMessage.error("获取系统配置失败:" + e.message)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
store.addMessageHandler("sd",(data) => {
|
||||||
|
// 丢弃无关消息
|
||||||
|
if (data.channel !== "sd" || data.clientId !== getClientId()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.body === "FINISH" || data.body === "FAIL") {
|
||||||
|
page.value = 0
|
||||||
|
isOver.value = false
|
||||||
|
fetchFinishJobs()
|
||||||
|
}
|
||||||
|
nextTick(() => fetchRunningJobs())
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
clipboard.value.destroy()
|
clipboard.value.destroy()
|
||||||
if (socket.value !== null) {
|
|
||||||
socket.value.close()
|
|
||||||
socket.value = null
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@ -624,7 +597,6 @@ const initData = () => {
|
|||||||
page.value = 0
|
page.value = 0
|
||||||
fetchRunningJobs()
|
fetchRunningJobs()
|
||||||
fetchFinishJobs()
|
fetchFinishJobs()
|
||||||
connect()
|
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -694,6 +666,7 @@ const generate = () => {
|
|||||||
httpPost("/api/sd/image", params.value).then(() => {
|
httpPost("/api/sd/image", params.value).then(() => {
|
||||||
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
|
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
|
||||||
power.value -= sdPower.value
|
power.value -= sdPower.value
|
||||||
|
fetchRunningJobs()
|
||||||
}).catch(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("任务推送失败:" + e.message)
|
ElMessage.error("任务推送失败:" + e.message)
|
||||||
})
|
})
|
||||||
|
@ -194,7 +194,7 @@ const fetchData = () => {
|
|||||||
const add = function () {
|
const add = function () {
|
||||||
showDialog.value = true
|
showDialog.value = true
|
||||||
title.value = "新增 API KEY"
|
title.value = "新增 API KEY"
|
||||||
item.value = {enabled: true,api_url: "https://api.chat-plus.net"}
|
item.value = {enabled: true,api_url: "https://api.geekai.pro"}
|
||||||
}
|
}
|
||||||
|
|
||||||
const edit = function (row) {
|
const edit = function (row) {
|
||||||
|
Loading…
Reference in New Issue
Block a user