feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy

This commit is contained in:
RockYang
2024-03-27 18:57:15 +08:00
parent 9794d67eaa
commit 360fea4085
13 changed files with 420 additions and 904 deletions

View File

@@ -2,13 +2,12 @@ package mj
import (
"chatplus/core/types"
"chatplus/service/mj/plus"
logger2 "chatplus/logger"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
@@ -16,7 +15,7 @@ import (
// ServicePool Mj service pool
type ServicePool struct {
services []interface{}
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
@@ -24,8 +23,10 @@ type ServicePool struct {
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]interface{}, 0)
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
@@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
if config.Enabled == false {
continue
}
client := plus.NewClient(config)
name := fmt.Sprintf("mj-service-plus-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
go func() {
servicePlus.Run()
service.Run()
}()
services = append(services, servicePlus)
services = append(services, service)
}
if len(services) == 0 {
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
for k, config := range appConfig.MjProxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
@@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
if err != nil {
continue
}
client := p.Clients.Get(userId)
if client == nil {
cli := p.Clients.Get(userId)
if cli == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = cli.Send([]byte("Task Updated"))
if err != nil {
continue
}
@@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
@@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
if client == nil {
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = client.Send([]byte("Task Updated"))
err = cli.Send([]byte("Task Updated"))
if err != nil {
continue
}
@@ -167,25 +149,6 @@ func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Debugf("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
return fmt.Errorf("非法任务:%s", data.Id)
}
// 任务已经拉取完成
if job.Progress == 100 {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(job)
}
return nil
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
@@ -222,11 +185,7 @@ func (p *ServicePool) SyncTaskProgress() {
}
}
if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
continue
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
@@ -236,12 +195,10 @@ func (p *ServicePool) SyncTaskProgress() {
}()
}
func (p *ServicePool) getServicePlus(name string) *plus.Service {
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if servicePlus, ok := s.(*plus.Service); ok {
if servicePlus.Name == name {
return servicePlus
}
if s.Name == name {
return s
}
}
return nil