mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
refactor: add midjourney pool implementation, add translate prompt for mj drawing
This commit is contained in:
@@ -19,17 +19,18 @@ var logger = logger2.GetLogger()
|
||||
type Bot struct {
|
||||
config *types.MidJourneyConfig
|
||||
bot *discordgo.Session
|
||||
name string
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
|
||||
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
|
||||
discord, err := discordgo.New("Bot " + config.BotToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ProxyURL != "" {
|
||||
proxy, _ := url.Parse(config.ProxyURL)
|
||||
if proxy != "" {
|
||||
proxy, _ := url.Parse(proxy)
|
||||
discord.Client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
@@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||
}
|
||||
|
||||
return &Bot{
|
||||
config: &config.MjConfigs,
|
||||
config: config,
|
||||
bot: discord,
|
||||
name: name,
|
||||
service: service,
|
||||
}, nil
|
||||
}
|
||||
@@ -52,13 +54,13 @@ func (b *Bot) Run() error {
|
||||
b.bot.AddHandler(b.messageCreate)
|
||||
b.bot.AddHandler(b.messageUpdate)
|
||||
|
||||
logger.Info("Starting MidJourney Bot...")
|
||||
logger.Infof("Starting MidJourney %s", b.name)
|
||||
err := b.bot.Open()
|
||||
if err != nil {
|
||||
logger.Error("Error opening Discord connection:", err)
|
||||
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Starting MidJourney Bot successfully!")
|
||||
logger.Infof("Starting MidJourney %s successfully!", b.name)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
@@ -18,9 +17,10 @@ type Client struct {
|
||||
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
|
||||
client := req.C().SetTimeout(10 * time.Second)
|
||||
// set proxy URL
|
||||
if utils.IsEmptyValue(proxy) {
|
||||
if proxy != "" {
|
||||
client.SetProxyURL(proxy)
|
||||
}
|
||||
logger.Info(proxy)
|
||||
return &Client{client: client, config: config}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,35 +4,63 @@ import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ServicePool Mj service pool
|
||||
type ServicePool struct {
|
||||
services []Service
|
||||
services []*Service
|
||||
taskQueue *store.RedisQueue
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||
// create mj client and service
|
||||
for _, config := range appConfig.MjConfigs {
|
||||
for k, config := range appConfig.MjConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(&config, appConfig.ProxyURL)
|
||||
|
||||
name := fmt.Sprintf("MjService-%d", k)
|
||||
// create mj service
|
||||
service := NewService()
|
||||
service := NewService(name, queue, 4, 600, db, client, manager, appConfig)
|
||||
botName := fmt.Sprintf("MjBot-%d", k)
|
||||
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err = bot.Run()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// run mj service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
taskQueue: queue,
|
||||
services: services,
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (p *ServicePool) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
p.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// HasAvailableService check if has available mj service in pool
|
||||
func (p *ServicePool) HasAvailableService() bool {
|
||||
return len(p.services) > 0
|
||||
}
|
||||
|
||||
@@ -27,16 +27,17 @@ type Service struct {
|
||||
snowflake *service.Snowflake
|
||||
}
|
||||
|
||||
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||
return &Service{
|
||||
name: name,
|
||||
db: db,
|
||||
taskQueue: queue,
|
||||
client: client,
|
||||
uploadManager: manager,
|
||||
taskTimeout: timeout,
|
||||
proxyURL: config.ProxyURL,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
name: name,
|
||||
db: db,
|
||||
taskQueue: queue,
|
||||
client: client,
|
||||
uploadManager: manager,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
proxyURL: config.ProxyURL,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +59,7 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("handle a new MidJourney task: %+v", task)
|
||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
err = s.client.Imagine(task.Prompt)
|
||||
@@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool {
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
// remove the expired tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
// delete task from database
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) {
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
job.OrgURL = data.Image.URL // save origin image
|
||||
job.OrgURL = data.Image.URL
|
||||
|
||||
// upload image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
if data.Status == Finished {
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
|
||||
@@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
|
||||
}
|
||||
|
||||
// Next 生成一个新的唯一ID
|
||||
func (s *Snowflake) Next() (string, error) {
|
||||
func (s *Snowflake) Next(raw bool) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
|
||||
|
||||
s.lastTimestamp = timestamp
|
||||
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
|
||||
if raw {
|
||||
return fmt.Sprintf("%d", id), nil
|
||||
}
|
||||
now := time.Now()
|
||||
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user