websocket api refactor is ready

This commit is contained in:
RockYang
2024-09-29 19:28:47 +08:00
parent 00a8bc6784
commit e28a12a1ee
19 changed files with 210 additions and 464 deletions

View File

@@ -34,19 +34,21 @@ type Service struct {
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
userService *service.UserService
wsService *service.WebsocketService
clientIds map[uint]string
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService, wsService *service.WebsocketService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
wsService: wsService,
uploadManager: manager,
userService: userService,
clientIds: map[uint]string{},
}
}
@@ -67,6 +69,7 @@ func (s *Service) Run() {
continue
}
logger.Infof("handle a new DALL-E task: %+v", task)
s.clientIds[task.JobId] = task.ClientId
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
@@ -74,7 +77,7 @@ func (s *Service) Run() {
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
}
}
}()
@@ -111,7 +114,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini", 0)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
@@ -183,7 +186,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", fmt.Errorf("err with update database: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@@ -205,14 +208,13 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
utils.SendChannelMsg(client, types.ChDall, message.Message)
}
}()
}
@@ -284,6 +286,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
s.notifyQueue.RPush(service.NotifyMessage{ClientId: s.clientIds[jobId], UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil
}