feat: support CDN reverse proxy for MidJourney and OpenAI API

This commit is contained in:
RockYang 2023-12-22 17:25:31 +08:00
parent de512a5ea2
commit 3ab930a107
19 changed files with 218 additions and 87 deletions

View File

@ -154,6 +154,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" || c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" || c.Request.URL.Path == "/api/sd/jobs" ||
c.Request.URL.Path == "/api/upload" || c.Request.URL.Path == "/api/upload" ||

View File

@ -34,11 +34,15 @@ type ChatPlusApiConfig struct {
} }
type MidJourneyConfig struct { type MidJourneyConfig struct {
Enabled bool Enabled bool
UserToken string UserToken string
BotToken string BotToken string
GuildId string // Server ID GuildId string // Server ID
ChanelId string // Chanel ID ChanelId string // Chanel ID
UseCDN bool
DiscordAPI string
DiscordCDN string
DiscordGateway string
} }
type StableDiffusionConfig struct { type StableDiffusionConfig struct {

View File

@ -6,7 +6,7 @@ import (
) )
type MKey interface { type MKey interface {
string | int string | int | uint
} }
type MValue interface { type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{} *WsClient | *ChatSession | context.CancelFunc | []interface{}

View File

@ -6,7 +6,6 @@ require (
github.com/BurntSushi/toml v1.1.0 github.com/BurntSushi/toml v1.1.0
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
github.com/bwmarrin/discordgo v0.27.1
github.com/eatmoreapple/openwechat v1.2.1 github.com/eatmoreapple/openwechat v1.2.1
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
@ -26,6 +25,8 @@ require (
require github.com/xxl-job/xxl-job-executor-go v1.2.0 require github.com/xxl-job/xxl-job-executor-go v1.2.0
require github.com/bg5t/mydiscordgo v0.28.1
require ( require (
github.com/andybalholm/brotli v1.0.4 // indirect github.com/andybalholm/brotli v1.0.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect

View File

@ -7,8 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY= github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=

View File

@ -442,7 +442,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
} else { } else {
client = http.DefaultClient client = http.DefaultClient
} }
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
switch platform { switch platform {
case types.Azure: case types.Azure:
request.Header.Set("api-key", *apiKey) request.Header.Set("api-key", *apiKey)
@ -452,7 +452,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.Info(token)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break break
case types.Baidu: case types.Baidu:

View File

@ -13,7 +13,9 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"strings" "strings"
"time" "time"
) )
@ -58,6 +60,27 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct { var data struct {
@ -147,6 +170,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
UserId: userId, UserId: userId,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
// update user's img calls // update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
resp.SUCCESS(c) resp.SUCCESS(c)
@ -205,6 +231,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -226,6 +256,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: types.TaskVariation.String(), Type: types.TaskVariation.String(),
ChannelId: data.ChannelId,
ReferenceId: data.MessageId, ReferenceId: data.MessageId,
UserId: userId, UserId: userId,
TaskId: data.TaskId, TaskId: data.TaskId,
@ -250,6 +281,9 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
client := h.pool.Clients.Get(uint(job.UserId))
_ = client.Send([]byte("Task Updated"))
// update user's img calls // update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
resp.SUCCESS(c) resp.SUCCESS(c)
@ -320,6 +354,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
func (h *MidJourneyHandler) Remove(c *gin.Context) { func (h *MidJourneyHandler) Remove(c *gin.Context) {
var data struct { var data struct {
Id uint `json:"id"` Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
} }
if err := c.ShouldBindJSON(&data); err != nil { if err := c.ShouldBindJSON(&data); err != nil {
@ -340,5 +375,8 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(data.UserId)
_ = client.Send([]byte("Task Updated"))
resp.SUCCESS(c) resp.SUCCESS(c)
} }

View File

@ -17,6 +17,7 @@ import (
"chatplus/store" "chatplus/store"
"context" "context"
"embed" "embed"
"github.com/go-redis/redis/v8"
"io" "io"
"log" "log"
"os" "os"
@ -25,8 +26,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/go-redis/redis/v8"
"github.com/lionsoul2014/ip2region/binding/golang/xdb" "github.com/lionsoul2014/ip2region/binding/golang/xdb"
"go.uber.org/fx" "go.uber.org/fx"
"gorm.io/gorm" "gorm.io/gorm"
@ -168,6 +167,7 @@ func main() {
fx.Invoke(func(pool *mj.ServicePool) { fx.Invoke(func(pool *mj.ServicePool) {
if pool.HasAvailableService() { if pool.HasAvailableService() {
pool.DownloadImages() pool.DownloadImages()
pool.CheckTaskNotify()
} }
}), }),
@ -234,6 +234,7 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/") group := s.Engine.Group("/api/mj/")
group.Any("client", h.Client)
group.POST("image", h.Image) group.POST("image", h.Image)
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)

View File

@ -4,7 +4,7 @@ import (
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/utils" "chatplus/utils"
"github.com/bwmarrin/discordgo" discordgo "github.com/bg5t/mydiscordgo"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"net/http" "net/http"
"net/url" "net/url"
@ -17,35 +17,48 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
type Bot struct { type Bot struct {
config *types.MidJourneyConfig config types.MidJourneyConfig
bot *discordgo.Session bot *discordgo.Session
name string name string
service *Service service *Service
} }
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) { func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.BotToken) bot, err := discordgo.New("Bot " + config.BotToken)
logger.Info(config.BotToken)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return nil, err return nil, err
} }
if proxy != "" { // use CDN reverse proxy
proxy, _ := url.Parse(proxy) if config.UseCDN {
discord.Client = &http.Client{ discordgo.SetEndpointDiscord(config.DiscordAPI)
Transport: &http.Transport{ discordgo.SetEndpointCDN(config.DiscordCDN)
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
bot.MjGateway = config.DiscordGateway + "/"
} else { // use proxy
discordgo.SetEndpointDiscord("https://discord.com")
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
bot.MjGateway = "wss://gateway.discord.gg"
if proxy != "" {
proxy, _ := url.Parse(proxy)
bot.Client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
bot.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy), Proxy: http.ProxyURL(proxy),
}, }
}
discord.Dialer = &websocket.Dialer{
Proxy: http.ProxyURL(proxy),
} }
} }
return &Bot{ return &Bot{
config: config, config: config,
bot: discord, bot: bot,
name: name, name: name,
service: service, service: service,
}, nil }, nil

View File

@ -12,24 +12,32 @@ import (
type Client struct { type Client struct {
client *req.Client client *req.Client
config types.MidJourneyConfig Config types.MidJourneyConfig
apiURL string
} }
func NewClient(config types.MidJourneyConfig, proxy string) *Client { func NewClient(config types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second) client := req.C().SetTimeout(10 * time.Second)
var apiURL string
// set proxy URL // set proxy URL
if proxy != "" { if config.UseCDN {
client.SetProxyURL(proxy) apiURL = config.DiscordAPI + "/api/v9/interactions"
} else {
apiURL = "https://discord.com/api/v9/interactions"
if proxy != "" {
client.SetProxyURL(proxy)
}
} }
return &Client{client: client, config: config}
return &Client{client: client, Config: config, apiURL: apiURL}
} }
func (c *Client) Imagine(prompt string) error { func (c *Client) Imagine(prompt string) error {
interactionsReq := &InteractionsRequest{ interactionsReq := &InteractionsRequest{
Type: 2, Type: 2,
ApplicationID: ApplicationID, ApplicationID: ApplicationID,
GuildID: c.config.GuildId, GuildID: c.Config.GuildId,
ChannelID: c.config.ChanelId, ChannelID: c.Config.ChanelId,
SessionID: SessionID, SessionID: SessionID,
Data: map[string]any{ Data: map[string]any{
"version": "1166847114203123795", "version": "1166847114203123795",
@ -67,11 +75,10 @@ func (c *Client) Imagine(prompt string) error {
}, },
} }
url := "https://discord.com/api/v9/interactions" r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(interactionsReq). SetBody(interactionsReq).
Post(url) Post(c.apiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %w%v", err, r.Err) return fmt.Errorf("error with http request: %w%v", err, r.Err)
@ -86,8 +93,8 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
interactionsReq := &InteractionsRequest{ interactionsReq := &InteractionsRequest{
Type: 3, Type: 3,
ApplicationID: ApplicationID, ApplicationID: ApplicationID,
GuildID: c.config.GuildId, GuildID: c.Config.GuildId,
ChannelID: c.config.ChanelId, ChannelID: c.Config.ChanelId,
MessageFlags: &flags, MessageFlags: &flags,
MessageID: &messageId, MessageID: &messageId,
SessionID: SessionID, SessionID: SessionID,
@ -98,13 +105,12 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
} }
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(interactionsReq). SetBody(interactionsReq).
SetErrorResult(&res). SetErrorResult(&res).
Post(url) Post(c.apiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
} }
@ -118,8 +124,8 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
interactionsReq := &InteractionsRequest{ interactionsReq := &InteractionsRequest{
Type: 3, Type: 3,
ApplicationID: ApplicationID, ApplicationID: ApplicationID,
GuildID: c.config.GuildId, GuildID: c.Config.GuildId,
ChannelID: c.config.ChanelId, ChannelID: c.Config.ChanelId,
MessageFlags: &flags, MessageFlags: &flags,
MessageID: &messageId, MessageID: &messageId,
SessionID: SessionID, SessionID: SessionID,
@ -130,13 +136,12 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
} }
url := "https://discord.com/api/v9/interactions"
var res InteractionsResult var res InteractionsResult
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(interactionsReq). SetBody(interactionsReq).
SetErrorResult(&res). SetErrorResult(&res).
Post(url) Post(c.apiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
} }

View File

@ -6,9 +6,9 @@ import (
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"fmt" "fmt"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -16,13 +16,16 @@ import (
type ServicePool struct { type ServicePool struct {
services []*Service services []*Service
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
uploaderManager *oss.UploaderManager uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
} }
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0) services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
// create mj client and service // create mj client and service
for k, config := range appConfig.MjConfigs { for k, config := range appConfig.MjConfigs {
if config.Enabled == false { if config.Enabled == false {
@ -33,9 +36,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
name := fmt.Sprintf("MjService-%d", k) name := fmt.Sprintf("MjService-%d", k)
// create mj service // create mj service
service := NewService(name, queue, 4, 600, db, client) service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k) botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service) bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil { if err != nil {
continue continue
} }
@ -54,13 +57,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
} }
return &ServicePool{ return &ServicePool{
taskQueue: queue, taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services, services: services,
uploaderManager: manager, uploaderManager: manager,
db: db, db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
} }
} }
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var userId uint
err := p.notifyQueue.LPop(&userId)
if err != nil {
continue
}
client := p.Clients.Get(userId)
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() { func (p *ServicePool) DownloadImages() {
go func() { go func() {
var items []model.MidJourneyJob var items []model.MidJourneyJob
@ -71,15 +93,21 @@ func (p *ServicePool) DownloadImages() {
} }
// download images // download images
for _, item := range items { for _, v := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(item.OrgURL, true) imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
if err != nil { if err != nil {
logger.Error("error with download image: ", err) logger.Error("error with download image: ", err)
continue continue
} }
item.ImgURL = imgURL v.ImgURL = imgURL
p.db.Updates(&item) p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
}
} }
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)

View File

@ -15,6 +15,7 @@ type Service struct {
name string // service name name string // service name
client *Client // MJ client client *Client // MJ client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number handledTaskNum int32 // already handled task number
@ -22,11 +23,12 @@ type Service struct {
taskTimeout int64 taskTimeout int64
} }
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{ return &Service{
name: name, name: name,
db: db, db: db,
taskQueue: queue, taskQueue: taskQueue,
notifyQueue: notifyQueue,
client: client, client: client,
taskTimeout: timeout, taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum, maxHandleTaskNum: maxTaskNum,
@ -53,9 +55,10 @@ func (s *Service) Run() {
} }
// if it's reference message, check if it's this channel's message // if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId { if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
s.taskQueue.RPush(task) s.taskQueue.RPush(task)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.notifyQueue.RPush(task.UserId)
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
} }
@ -77,6 +80,7 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err) logger.Error("绘画任务执行失败:", err)
// update the task progress // update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.notifyQueue.RPush(task.UserId)
// restore img_call quota // restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue continue
@ -134,6 +138,10 @@ func (s *Service) Notify(data CBReq) {
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.Hash = data.Image.Hash job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL job.OrgURL = data.Image.URL
if s.client.Config.UseCDN {
job.UseProxy = true
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
}
res = s.db.Updates(&job) res = s.db.Updates(&job)
if res.Error != nil { if res.Error != nil {
@ -146,4 +154,6 @@ func (s *Service) Notify(data CBReq) {
atomic.AddInt32(&s.handledTaskNum, -1) atomic.AddInt32(&s.handledTaskNum, -1)
} }
s.notifyQueue.RPush(job.UserId)
} }

View File

@ -3,6 +3,7 @@ package model
type Function struct { type Function struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
Name string Name string
Label string
Description string Description string
Parameters string Parameters string
Required string Required string

View File

@ -15,6 +15,7 @@ type MidJourneyJob struct {
Hash string // message hash Hash string // message hash
Progress int Progress int
Prompt string Prompt string
UseProxy bool // 是否使用反代加载图片
CreatedAt time.Time CreatedAt time.Time
} }

View File

@ -15,5 +15,6 @@ type MidJourneyJob struct {
Hash string `json:"hash"` Hash string `json:"hash"`
Progress int `json:"progress"` Progress int `json:"progress"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
UseProxy bool `json:"use_proxy"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
} }

View File

@ -18,4 +18,7 @@ ALTER TABLE `chatgpt_functions` ADD UNIQUE(`name`);
ALTER TABLE `chatgpt_functions` ADD `enabled` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否启用' AFTER `action`; ALTER TABLE `chatgpt_functions` ADD `enabled` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否启用' AFTER `action`;
ALTER TABLE `chatgpt_functions` ADD `lebal` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`; ALTER TABLE `chatgpt_functions` ADD `label` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`;
ALTER TABLE `chatgpt_mj_jobs` ADD `use_proxy` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否使用反代' AFTER `progress`;
ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL';

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -359,7 +359,7 @@
<template #default="scope"> <template #default="scope">
<div class="job-item"> <div class="job-item">
<el-image <el-image
:src="scope.item.type === 'upscale' ? scope.item['img_url'] + '?imageView2/1/w/480/h/600/q/75' : scope.item['img_url'] + '?imageView2/1/w/480/h/480/q/75'" :src="scope.item['thumb_url']"
:class="scope.item.type === 'upscale' ? 'upscale' : ''" :zoom-rate="1.2" :class="scope.item.type === 'upscale' ? 'upscale' : ''" :zoom-rate="1.2"
:preview-src-list="[scope.item['img_url']]" fit="cover" :initial-index="scope.index" :preview-src-list="[scope.item['img_url']]" fit="cover" :initial-index="scope.index"
loading="lazy" v-if="scope.item.progress > 0"> loading="lazy" v-if="scope.item.progress > 0">
@ -477,7 +477,8 @@ const rates = [
{css: "size9-16", value: "9:16", text: "9:16", img: "/images/mj/rate_9_16.png"}, {css: "size9-16", value: "9:16", text: "9:16", img: "/images/mj/rate_9_16.png"},
] ]
const models = [ const models = [
{text: "最新模式MJ-5.2", value: " --v 5.2", img: "/images/mj/mj-v5.2.png"}, {text: "写实模式MJ-6.0", value: " --v 6", img: "/images/mj/mj-v6.png"},
{text: "优质模式MJ-5.2", value: " --v 5.2", img: "/images/mj/mj-v5.2.png"},
{text: "优质模式MJ-5.1", value: " --v 5.1", img: "/images/mj/mj-v5.1.jpg"}, {text: "优质模式MJ-5.1", value: " --v 5.1", img: "/images/mj/mj-v5.1.jpg"},
{text: "虚幻模式MJ-5", value: " --v 5", img: "/images/mj/mj-v5.jpg"}, {text: "虚幻模式MJ-5", value: " --v 5", img: "/images/mj/mj-v5.jpg"},
{text: "真实模式MJ-4", value: " --v 4", img: "/images/mj/mj-v4.jpg"}, {text: "真实模式MJ-4", value: " --v 4", img: "/images/mj/mj-v4.jpg"},
@ -489,6 +490,10 @@ const models = [
] ]
const options = [ const options = [
{
value: 0,
label: '默认'
},
{ {
value: 0.25, value: 0.25,
label: '普通' label: '普通'
@ -515,7 +520,7 @@ const params = ref({
prompt: "", prompt: "",
neg_prompt: "", neg_prompt: "",
tile: false, tile: false,
quality: 0.5 quality: 0
}) })
const activeName = ref('图生图') const activeName = ref('图生图')
@ -527,6 +532,7 @@ const router = useRouter()
const socket = ref(null) const socket = ref(null)
const imgCalls = ref(0) const imgCalls = ref(0)
const loading = ref(false) const loading = ref(false)
const userId = ref(0)
const rewritePrompt = () => { const rewritePrompt = () => {
loading.value = true loading.value = true
@ -550,12 +556,40 @@ const translatePrompt = () => {
}) })
} }
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/mj/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
fetchRunningJobs(userId.value)
fetchFinishJobs(userId.value)
}
});
_socket.addEventListener('close', () => {
connect()
});
}
onMounted(() => { onMounted(() => {
checkSession().then(user => { checkSession().then(user => {
imgCalls.value = user['img_calls'] imgCalls.value = user['img_calls']
userId.value = user.id
fetchRunningJobs(user.id) fetchRunningJobs(userId.value)
fetchFinishJobs(user.id) fetchFinishJobs(userId.value)
connect()
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
@ -588,39 +622,29 @@ const fetchRunningJobs = (userId) => {
_jobs.push(jobs[i]) _jobs.push(jobs[i])
} }
runningJobs.value = _jobs runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 1000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
setTimeout(() => fetchRunningJobs(userId), 5000)
}) })
} }
const fetchFinishJobs = (userId) => { const fetchFinishJobs = (userId) => {
// //
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => { httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
if (finishedJobs.value.length === 0) { const jobs = res.data
finishedJobs.value = res.data for (let i = 0; i < jobs.length; i++) {
return if (jobs[i]['use_proxy']) {
} jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?x-oss-process=image/quality,q_60&format=webp'
} else {
// check if the img url is changed if (jobs[i].type === 'upscale') {
const list = res.data jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?imageView2/1/w/480/h/600/q/75'
let changed = false } else {
for (let i = 0; i < list.length; i++) { jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?imageView2/1/w/480/h/480/q/75'
if (list[i]["img_url"] !== finishedJobs.value[i]["img_url"]) { }
changed = true
break
} }
} }
if (changed) { finishedJobs.value = jobs
finishedJobs.value = list
}
setTimeout(() => fetchFinishJobs(userId), 1000)
}).catch(e => { }).catch(e => {
ElMessage.error("获取任务失败:" + e.message) ElMessage.error("获取任务失败:" + e.message)
setTimeout(() => fetchFinishJobs(userId), 5000)
}) })
} }
@ -710,7 +734,7 @@ const removeImage = (item) => {
type: 'warning', type: 'warning',
} }
).then(() => { ).then(() => {
httpPost("/api/mj/remove", {id: item.id, img_url: item.img_url}).then(() => { httpPost("/api/mj/remove", {id: item.id, img_url: item.img_url, user_id: userId.value}).then(() => {
ElMessage.success("任务删除成功") ElMessage.success("任务删除成功")
}).catch(e => { }).catch(e => {
ElMessage.error("任务删除失败:" + e.message) ElMessage.error("任务删除失败:" + e.message)

View File

@ -612,8 +612,9 @@ onMounted(() => {
// //
const fetchFinishJobs = (userId) => { const fetchFinishJobs = (userId) => {
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => { httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
if (finishedJobs.value.length === 0) { if (finishedJobs.value.length === 0 || res.data.length > finishedJobs.value.length) {
finishedJobs.value = res.data finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 1000)
return return
} }