websocket api refactor is ready

This commit is contained in:
RockYang
2024-09-29 19:28:47 +08:00
parent 8093a3eeb2
commit 1a1734abf0
19 changed files with 210 additions and 464 deletions

View File

@@ -33,7 +33,6 @@ type Service struct {
notifyQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
leveldb *store.LevelDB
wsService *service.WebsocketService
}
@@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db,
leveldb: levelDB,
wsService: wsService,
uploadManager: manager,
}
@@ -62,7 +60,7 @@ func (s *Service) Run() {
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
if err == nil {
task.Params.Prompt = content
} else {
@@ -72,7 +70,7 @@ func (s *Service) Run() {
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
if err == nil {
task.Params.NegPrompt = content
} else {
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
// TaskProgressResp 任务进度响应实体
type TaskProgressResp struct {
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
CurrentImage string `json:"current_image"`
Progress float64 `json:"progress"`
EtaRelative float64 `json:"eta_relative"`
}
// Txt2Img 文生图 API
@@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
default:
err, resp := s.checkTaskProgress(apiKey)
@@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
}
}
time.Sleep(time.Second)
}
@@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue