mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
sd websocket refactor is finished
This commit is contained in:
@@ -43,12 +43,14 @@ type MjTask struct {
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
Type TaskType `json:"type"`
|
||||
ClientId string `json:"client_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Params SdTaskParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
||||
|
||||
type SdTaskParams struct {
|
||||
ClientId string `json:"client_id"` // 客户端ID
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegPrompt string `json:"neg_prompt"` // 反向提示词
|
||||
|
||||
@@ -144,17 +144,13 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.sdService.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
Id: int(job.Id),
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskImage,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
client := h.sdService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
|
||||
@@ -34,17 +34,17 @@ type Service struct {
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
leveldb *store.LevelDB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C(),
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
@@ -223,7 +223,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
@@ -267,14 +267,11 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChSd, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ const (
|
||||
)
|
||||
|
||||
type NotifyMessage struct {
|
||||
UserId int `json:"user_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
UserId int `json:"user_id"`
|
||||
ClientId string `json:"client_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
|
||||
|
||||
Reference in New Issue
Block a user