refactor: add midjourney pool implementation, add translate prompt for mj drawing

This commit is contained in:
RockYang
2023-12-13 16:38:27 +08:00
parent 8f4d20e411
commit 6d71f24f75
16 changed files with 226 additions and 272 deletions

View File

@@ -27,16 +27,17 @@ type Service struct {
snowflake *service.Snowflake
}
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
}
}
@@ -58,7 +59,7 @@ func (s *Service) Run() {
continue
}
logger.Infof("handle a new MidJourney task: %+v", task)
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
@@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool {
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
@@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL // save origin image
job.OrgURL = data.Image.URL
// upload image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
if data.Status == Finished {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
}
job.ImgURL = imgURL
res = s.db.Updates(&job)
if res.Error != nil {