mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-12 12:13:46 +08:00
mj websocket refactor is ready
This commit is contained in:
@@ -24,8 +24,9 @@ const (
|
||||
|
||||
// MjTask MidJourney 任务
|
||||
type MjTask struct {
|
||||
Id uint `json:"id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Id uint `json:"id"` // 任务ID
|
||||
TaskId string `json:"task_id"` // 中转任务ID
|
||||
ClientId string `json:"client_id"`
|
||||
ImgArr []string `json:"img_arr"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
|
||||
@@ -8,7 +8,6 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
@@ -67,6 +66,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
TaskType string `json:"task_type"`
|
||||
ClientId string `json:"client_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegPrompt string `json:"neg_prompt"`
|
||||
Rate string `json:"rate"`
|
||||
@@ -177,6 +177,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
ClientId: data.ClientId,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
Prompt: data.Prompt,
|
||||
@@ -187,11 +188,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -208,6 +204,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
type reqVo struct {
|
||||
Index int `json:"index"`
|
||||
ClientId string `json:"client_id"`
|
||||
ChannelId string `json:"channel_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
@@ -244,6 +241,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
ClientId: data.ClientId,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -253,11 +251,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
// update user's power
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -305,6 +298,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskVariation,
|
||||
ClientId: data.ClientId,
|
||||
UserId: userId,
|
||||
Index: data.Index,
|
||||
ChannelId: data.ChannelId,
|
||||
@@ -313,11 +307,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "mid-journey",
|
||||
@@ -397,14 +386,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
|
||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
return nil, vo.NewPage(total, page, pageSize, jobs)
|
||||
@@ -449,11 +430,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.mjService.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,18 +28,20 @@ type Service struct {
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
wsService *service.WebsocketService
|
||||
uploaderManager *oss.UploaderManager
|
||||
clientIds map[uint]string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, wsService *service.WebsocketService) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
|
||||
client: client,
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
wsService: wsService,
|
||||
uploaderManager: manager,
|
||||
clientIds: map[uint]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +79,7 @@ func (s *Service) Run() {
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
}
|
||||
s.clientIds[task.Id] = task.ClientId
|
||||
|
||||
var job model.MidJourneyJob
|
||||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||||
@@ -119,7 +122,7 @@ func (s *Service) Run() {
|
||||
// update the task progress
|
||||
s.db.Updates(&job)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
@@ -166,14 +169,11 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cli := s.Clients.Get(uint(message.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
utils.SendChannelMsg(client, types.ChMj, message.Message)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -211,14 +211,11 @@ func (s *Service) DownloadImages() {
|
||||
v.ImgURL = imgURL
|
||||
s.db.Updates(&v)
|
||||
|
||||
cli := s.Clients.Get(uint(v.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = cli.Send([]byte(service.TaskStatusFinished))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[v.Id],
|
||||
UserId: v.UserId,
|
||||
JobId: int(v.Id),
|
||||
Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
@@ -264,7 +261,11 @@ func (s *Service) SyncTaskProgress() {
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
logger.Errorf("task failed: %v", task.FailReason)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -289,7 +290,11 @@ func (s *Service) SyncTaskProgress() {
|
||||
if job.Progress == 100 {
|
||||
message = service.TaskStatusFinished
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{
|
||||
ClientId: s.clientIds[job.Id],
|
||||
UserId: job.UserId,
|
||||
JobId: int(job.Id),
|
||||
Message: message})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user