refactor: add midjourney pool implementation, add translate prompt for mj drawing

This commit is contained in:
RockYang
2023-12-13 16:38:27 +08:00
parent 8f4d20e411
commit 6d71f24f75
16 changed files with 226 additions and 272 deletions

View File

@@ -19,17 +19,18 @@ var logger = logger2.GetLogger()
type Bot struct {
config *types.MidJourneyConfig
bot *discordgo.Session
name string
service *Service
}
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.BotToken)
if err != nil {
return nil, err
}
if config.ProxyURL != "" {
proxy, _ := url.Parse(config.ProxyURL)
if proxy != "" {
proxy, _ := url.Parse(proxy)
discord.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
@@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
}
return &Bot{
config: &config.MjConfigs,
config: config,
bot: discord,
name: name,
service: service,
}, nil
}
@@ -52,13 +54,13 @@ func (b *Bot) Run() error {
b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate)
logger.Info("Starting MidJourney Bot...")
logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open()
if err != nil {
logger.Error("Error opening Discord connection:", err)
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err
}
logger.Info("Starting MidJourney Bot successfully!")
logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil
}

View File

@@ -2,7 +2,6 @@ package mj
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/imroc/req/v3"
"time"
@@ -18,9 +17,10 @@ type Client struct {
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
// set proxy URL
if utils.IsEmptyValue(proxy) {
if proxy != "" {
client.SetProxyURL(proxy)
}
logger.Info(proxy)
return &Client{client: client, config: config}
}

View File

@@ -4,35 +4,63 @@ import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []Service
services []*Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
// create mj client and service
for _, config := range appConfig.MjConfigs {
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(&config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService()
service := NewService(name, queue, 4, 600, db, client, manager, appConfig)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
taskQueue: queue,
services: services,
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -27,16 +27,17 @@ type Service struct {
snowflake *service.Snowflake
}
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
}
}
@@ -58,7 +59,7 @@ func (s *Service) Run() {
continue
}
logger.Infof("handle a new MidJourney task: %+v", task)
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
@@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool {
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
@@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL // save origin image
job.OrgURL = data.Image.URL
// upload image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
if data.Status == Finished {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
}
job.ImgURL = imgURL
res = s.db.Updates(&job)
if res.Error != nil {

View File

@@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
}
// Next 生成一个新的唯一ID
func (s *Snowflake) Next() (string, error) {
func (s *Snowflake) Next(raw bool) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
s.lastTimestamp = timestamp
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
if raw {
return fmt.Sprintf("%d", id), nil
}
now := time.Now()
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
}