mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
feat: support CDN reverse proxy for MidJourney and OpenAI API
This commit is contained in:
@@ -15,6 +15,7 @@ type Service struct {
|
||||
name string // service name
|
||||
client *Client // MJ client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
handledTaskNum int32 // already handled task number
|
||||
@@ -22,11 +23,12 @@ type Service struct {
|
||||
taskTimeout int64
|
||||
}
|
||||
|
||||
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
||||
return &Service{
|
||||
name: name,
|
||||
db: db,
|
||||
taskQueue: queue,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
client: client,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
@@ -53,9 +55,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
// if it's reference message, check if it's this channel's message
|
||||
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
|
||||
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
||||
s.taskQueue.RPush(task)
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
@@ -77,6 +80,7 @@ func (s *Service) Run() {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
// restore img_call quota
|
||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
||||
continue
|
||||
@@ -134,6 +138,10 @@ func (s *Service) Notify(data CBReq) {
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
job.OrgURL = data.Image.URL
|
||||
if s.client.Config.UseCDN {
|
||||
job.UseProxy = true
|
||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
|
||||
}
|
||||
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
@@ -146,4 +154,6 @@ func (s *Service) Notify(data CBReq) {
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user