feat: support CDN reverse proxy for MidJourney and OpenAI API

This commit is contained in:
RockYang
2023-12-22 17:25:31 +08:00
parent 6754c8e85e
commit 18b7484c5b
19 changed files with 218 additions and 87 deletions

View File

@@ -4,7 +4,7 @@ import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"github.com/bwmarrin/discordgo"
discordgo "github.com/bg5t/mydiscordgo"
"github.com/gorilla/websocket"
"net/http"
"net/url"
@@ -17,35 +17,48 @@ import (
var logger = logger2.GetLogger()
type Bot struct {
config *types.MidJourneyConfig
config types.MidJourneyConfig
bot *discordgo.Session
name string
service *Service
}
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.BotToken)
logger.Info(config.BotToken)
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
bot, err := discordgo.New("Bot " + config.BotToken)
if err != nil {
logger.Error(err)
return nil, err
}
if proxy != "" {
proxy, _ := url.Parse(proxy)
discord.Client = &http.Client{
Transport: &http.Transport{
// use CDN reverse proxy
if config.UseCDN {
discordgo.SetEndpointDiscord(config.DiscordAPI)
discordgo.SetEndpointCDN(config.DiscordCDN)
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
bot.MjGateway = config.DiscordGateway + "/"
} else { // use proxy
discordgo.SetEndpointDiscord("https://discord.com")
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
bot.MjGateway = "wss://gateway.discord.gg"
if proxy != "" {
proxy, _ := url.Parse(proxy)
bot.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
bot.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
},
}
discord.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
}
}
}
return &Bot{
config: config,
bot: discord,
bot: bot,
name: name,
service: service,
}, nil

View File

@@ -12,24 +12,32 @@ import (
type Client struct {
client *req.Client
config types.MidJourneyConfig
Config types.MidJourneyConfig
apiURL string
}
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
var apiURL string
// set proxy URL
if proxy != "" {
client.SetProxyURL(proxy)
if config.UseCDN {
apiURL = config.DiscordAPI + "/api/v9/interactions"
} else {
apiURL = "https://discord.com/api/v9/interactions"
if proxy != "" {
client.SetProxyURL(proxy)
}
}
return &Client{client: client, config: config}
return &Client{client: client, Config: config, apiURL: apiURL}
}
func (c *Client) Imagine(prompt string) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
SessionID: SessionID,
Data: map[string]any{
"version": "1166847114203123795",
@@ -67,11 +75,10 @@ func (c *Client) Imagine(prompt string) error {
},
}
url := "https://discord.com/api/v9/interactions"
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
Post(url)
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err)
@@ -86,8 +93,8 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
@@ -98,13 +105,12 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}
@@ -118,8 +124,8 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.config.GuildId,
ChannelID: c.config.ChanelId,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
SessionID: SessionID,
@@ -130,13 +136,12 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json").
SetBody(interactionsReq).
SetErrorResult(&res).
Post(url)
Post(c.apiURL)
if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
}

View File

@@ -6,9 +6,9 @@ import (
"chatplus/store"
"chatplus/store/model"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
@@ -16,13 +16,16 @@ import (
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
@@ -33,9 +36,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, queue, 4, 600, db, client)
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
@@ -54,13 +57,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
}
return &ServicePool{
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
@@ -71,15 +93,21 @@ func (p *ServicePool) DownloadImages() {
}
// download images
for _, item := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(item.OrgURL, true)
for _, v := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
if err != nil {
logger.Error("error with download image: ", err)
continue
}
item.ImgURL = imgURL
p.db.Updates(&item)
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)

View File

@@ -15,6 +15,7 @@ type Service struct {
name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
@@ -22,11 +23,12 @@ type Service struct {
taskTimeout int64
}
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{
name: name,
db: db,
taskQueue: queue,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
client: client,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
@@ -53,9 +55,10 @@ func (s *Service) Run() {
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
s.taskQueue.RPush(task)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.notifyQueue.RPush(task.UserId)
time.Sleep(time.Second)
continue
}
@@ -77,6 +80,7 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err)
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
@@ -134,6 +138,10 @@ func (s *Service) Notify(data CBReq) {
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL
if s.client.Config.UseCDN {
job.UseProxy = true
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
}
res = s.db.Updates(&job)
if res.Error != nil {
@@ -146,4 +154,6 @@ func (s *Service) Notify(data CBReq) {
atomic.AddInt32(&s.handledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId)
}