feat: support CDN reverse proxy for MidJourney and OpenAI API

This commit is contained in:
RockYang
2023-12-22 17:25:31 +08:00
parent de512a5ea2
commit 3ab930a107
19 changed files with 218 additions and 87 deletions

View File

@@ -6,9 +6,9 @@ import (
"chatplus/store"
"chatplus/store/model"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
@@ -16,13 +16,16 @@ import (
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
@@ -33,9 +36,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, queue, 4, 600, db, client)
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
@@ -54,13 +57,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
}
return &ServicePool{
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
@@ -71,15 +93,21 @@ func (p *ServicePool) DownloadImages() {
}
// download images
for _, item := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(item.OrgURL, true)
for _, v := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
item.ImgURL = imgURL
p.db.Updates(&item)
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)