mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 19:23:42 +08:00
feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy
This commit is contained in:
@@ -2,13 +2,12 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj/plus"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -16,7 +15,7 @@ import (
|
||||
|
||||
// ServicePool Mj service pool
|
||||
type ServicePool struct {
|
||||
services []interface{}
|
||||
services []*Service
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
@@ -24,8 +23,10 @@ type ServicePool struct {
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
services := make([]interface{}, 0)
|
||||
services := make([]*Service, 0)
|
||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||
|
||||
@@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
client := plus.NewClient(config)
|
||||
name := fmt.Sprintf("mj-service-plus-%d", k)
|
||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
||||
cli := NewPlusClient(config)
|
||||
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
go func() {
|
||||
servicePlus.Run()
|
||||
service.Run()
|
||||
}()
|
||||
services = append(services, servicePlus)
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
if len(services) == 0 {
|
||||
// create mj client and service
|
||||
for k, config := range appConfig.MjConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(config, appConfig.ProxyURL)
|
||||
|
||||
name := fmt.Sprintf("MjService-%d", k)
|
||||
// create mj service
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
||||
botName := fmt.Sprintf("MjBot-%d", k)
|
||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err = bot.Run()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// run mj service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
services = append(services, service)
|
||||
for k, config := range appConfig.MjProxyConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
cli := NewProxyClient(config)
|
||||
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
services = append(services, service)
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
@@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := p.Clients.Get(userId)
|
||||
if client == nil {
|
||||
cli := p.Clients.Get(userId)
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
err = cli.Send([]byte("Task Updated"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
var imgURL string
|
||||
var err error
|
||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
||||
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
|
||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
|
||||
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
||||
} else {
|
||||
@@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
|
||||
v.ImgURL = imgURL
|
||||
p.db.Updates(&v)
|
||||
|
||||
client := p.Clients.Get(uint(v.UserId))
|
||||
if client == nil {
|
||||
cli := p.Clients.Get(uint(v.UserId))
|
||||
if cli == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
err = cli.Send([]byte("Task Updated"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -167,25 +149,6 @@ func (p *ServicePool) HasAvailableService() bool {
|
||||
return len(p.services) > 0
|
||||
}
|
||||
|
||||
func (p *ServicePool) Notify(data plus.CBReq) error {
|
||||
logger.Debugf("收到任务回调:%+v", data)
|
||||
var job model.MidJourneyJob
|
||||
res := p.db.Where("task_id = ?", data.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("非法任务:%s", data.Id)
|
||||
}
|
||||
|
||||
// 任务已经拉取完成
|
||||
if job.Progress == 100 {
|
||||
return nil
|
||||
}
|
||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
||||
return servicePlus.Notify(job)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (p *ServicePool) SyncTaskProgress() {
|
||||
go func() {
|
||||
@@ -222,11 +185,7 @@ func (p *ServicePool) SyncTaskProgress() {
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
|
||||
continue
|
||||
}
|
||||
|
||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
||||
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
||||
_ = servicePlus.Notify(job)
|
||||
}
|
||||
}
|
||||
@@ -236,12 +195,10 @@ func (p *ServicePool) SyncTaskProgress() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
||||
func (p *ServicePool) getService(name string) *Service {
|
||||
for _, s := range p.services {
|
||||
if servicePlus, ok := s.(*plus.Service); ok {
|
||||
if servicePlus.Name == name {
|
||||
return servicePlus
|
||||
}
|
||||
if s.Name == name {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user