mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 18:23:45 +08:00
websocket api refactor is ready
This commit is contained in:
@@ -33,7 +33,6 @@ type Service struct {
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
leveldb *store.LevelDB
|
||||
wsService *service.WebsocketService
|
||||
}
|
||||
|
||||
@@ -43,7 +42,6 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelD
|
||||
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
|
||||
db: db,
|
||||
leveldb: levelDB,
|
||||
wsService: wsService,
|
||||
uploadManager: manager,
|
||||
}
|
||||
@@ -62,7 +60,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
@@ -72,7 +70,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini", 0)
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
@@ -126,9 +124,8 @@ type Txt2ImgResp struct {
|
||||
|
||||
// TaskProgressResp 任务进度响应实体
|
||||
type TaskProgressResp struct {
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
CurrentImage string `json:"current_image"`
|
||||
Progress float64 `json:"progress"`
|
||||
EtaRelative float64 `json:"eta_relative"`
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
@@ -214,8 +211,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
default:
|
||||
err, resp := s.checkTaskProgress(apiKey)
|
||||
@@ -224,10 +219,6 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(service.NotifyMessage{ClientId: task.ClientId, UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
@@ -267,6 +258,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logger.Debugf("notify message: %+v", message)
|
||||
client := s.wsService.Clients.Get(message.ClientId)
|
||||
if client == nil {
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user