mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	feat: support CDN reverse proxy for MidJourney and OpenAI API
This commit is contained in:
		@@ -154,6 +154,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
				
			|||||||
			c.Request.URL.Path == "/api/chat/detail" ||
 | 
								c.Request.URL.Path == "/api/chat/detail" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/role/list" ||
 | 
								c.Request.URL.Path == "/api/role/list" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/mj/jobs" ||
 | 
								c.Request.URL.Path == "/api/mj/jobs" ||
 | 
				
			||||||
 | 
								c.Request.URL.Path == "/api/mj/client" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/invite/hits" ||
 | 
								c.Request.URL.Path == "/api/invite/hits" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/sd/jobs" ||
 | 
								c.Request.URL.Path == "/api/sd/jobs" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/upload" ||
 | 
								c.Request.URL.Path == "/api/upload" ||
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,11 +34,15 @@ type ChatPlusApiConfig struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MidJourneyConfig struct {
 | 
					type MidJourneyConfig struct {
 | 
				
			||||||
	Enabled   bool
 | 
						Enabled        bool
 | 
				
			||||||
	UserToken string
 | 
						UserToken      string
 | 
				
			||||||
	BotToken  string
 | 
						BotToken       string
 | 
				
			||||||
	GuildId   string // Server ID
 | 
						GuildId        string // Server ID
 | 
				
			||||||
	ChanelId  string // Chanel ID
 | 
						ChanelId       string // Chanel ID
 | 
				
			||||||
 | 
						UseCDN         bool
 | 
				
			||||||
 | 
						DiscordAPI     string
 | 
				
			||||||
 | 
						DiscordCDN     string
 | 
				
			||||||
 | 
						DiscordGateway string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type StableDiffusionConfig struct {
 | 
					type StableDiffusionConfig struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,7 +6,7 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MKey interface {
 | 
					type MKey interface {
 | 
				
			||||||
	string | int
 | 
						string | int | uint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
type MValue interface {
 | 
					type MValue interface {
 | 
				
			||||||
	*WsClient | *ChatSession | context.CancelFunc | []interface{}
 | 
						*WsClient | *ChatSession | context.CancelFunc | []interface{}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,7 +6,6 @@ require (
 | 
				
			|||||||
	github.com/BurntSushi/toml v1.1.0
 | 
						github.com/BurntSushi/toml v1.1.0
 | 
				
			||||||
	github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
 | 
						github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
 | 
				
			||||||
	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
 | 
						github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
 | 
				
			||||||
	github.com/bwmarrin/discordgo v0.27.1
 | 
					 | 
				
			||||||
	github.com/eatmoreapple/openwechat v1.2.1
 | 
						github.com/eatmoreapple/openwechat v1.2.1
 | 
				
			||||||
	github.com/gin-gonic/gin v1.9.1
 | 
						github.com/gin-gonic/gin v1.9.1
 | 
				
			||||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
						github.com/go-redis/redis/v8 v8.11.5
 | 
				
			||||||
@@ -26,6 +25,8 @@ require (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
 | 
					require github.com/xxl-job/xxl-job-executor-go v1.2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					require github.com/bg5t/mydiscordgo v0.28.1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/andybalholm/brotli v1.0.4 // indirect
 | 
						github.com/andybalholm/brotli v1.0.4 // indirect
 | 
				
			||||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
						github.com/bytedance/sonic v1.9.1 // indirect
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,8 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
 | 
				
			|||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 | 
					github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 | 
				
			||||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 | 
					github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 | 
				
			||||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
 | 
					github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
 | 
				
			||||||
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
 | 
					github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8=
 | 
				
			||||||
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
 | 
					github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8=
 | 
				
			||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
					github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
				
			||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
					github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
				
			||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
					github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -442,7 +442,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
				
			|||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		client = http.DefaultClient
 | 
							client = http.DefaultClient
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
 | 
						logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model)
 | 
				
			||||||
	switch platform {
 | 
						switch platform {
 | 
				
			||||||
	case types.Azure:
 | 
						case types.Azure:
 | 
				
			||||||
		request.Header.Set("api-key", *apiKey)
 | 
							request.Header.Set("api-key", *apiKey)
 | 
				
			||||||
@@ -452,7 +452,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		logger.Info(token)
 | 
					 | 
				
			||||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
							request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
				
			||||||
		break
 | 
							break
 | 
				
			||||||
	case types.Baidu:
 | 
						case types.Baidu:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,7 +13,9 @@ import (
 | 
				
			|||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -58,6 +60,27 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Client WebSocket 客户端,用于通知任务状态变更
 | 
				
			||||||
 | 
					func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
				
			||||||
 | 
						ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logger.Error(err)
 | 
				
			||||||
 | 
							c.Abort()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						userId := h.GetInt(c, "user_id", 0)
 | 
				
			||||||
 | 
						if userId == 0 {
 | 
				
			||||||
 | 
							logger.Info("Invalid user ID")
 | 
				
			||||||
 | 
							c.Abort()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := types.NewWsClient(ws)
 | 
				
			||||||
 | 
						h.pool.Clients.Put(uint(userId), client)
 | 
				
			||||||
 | 
						logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Image 创建一个绘画任务
 | 
					// Image 创建一个绘画任务
 | 
				
			||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
					func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
@@ -147,6 +170,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
				
			|||||||
		UserId:    userId,
 | 
							UserId:    userId,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := h.pool.Clients.Get(uint(job.UserId))
 | 
				
			||||||
 | 
						_ = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// update user's img calls
 | 
						// update user's img calls
 | 
				
			||||||
	h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
						h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
@@ -205,6 +231,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
		MessageId:   data.MessageId,
 | 
							MessageId:   data.MessageId,
 | 
				
			||||||
		MessageHash: data.MessageHash,
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := h.pool.Clients.Get(uint(job.UserId))
 | 
				
			||||||
 | 
						_ = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -226,6 +256,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	job := model.MidJourneyJob{
 | 
						job := model.MidJourneyJob{
 | 
				
			||||||
		Type:        types.TaskVariation.String(),
 | 
							Type:        types.TaskVariation.String(),
 | 
				
			||||||
 | 
							ChannelId:   data.ChannelId,
 | 
				
			||||||
		ReferenceId: data.MessageId,
 | 
							ReferenceId: data.MessageId,
 | 
				
			||||||
		UserId:      userId,
 | 
							UserId:      userId,
 | 
				
			||||||
		TaskId:      data.TaskId,
 | 
							TaskId:      data.TaskId,
 | 
				
			||||||
@@ -250,6 +281,9 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
		MessageHash: data.MessageHash,
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := h.pool.Clients.Get(uint(job.UserId))
 | 
				
			||||||
 | 
						_ = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// update user's img calls
 | 
						// update user's img calls
 | 
				
			||||||
	h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
						h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
@@ -320,6 +354,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
				
			|||||||
func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
					func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
		Id     uint   `json:"id"`
 | 
							Id     uint   `json:"id"`
 | 
				
			||||||
 | 
							UserId uint   `json:"user_id"`
 | 
				
			||||||
		ImgURL string `json:"img_url"`
 | 
							ImgURL string `json:"img_url"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
						if err := c.ShouldBindJSON(&data); err != nil {
 | 
				
			||||||
@@ -340,5 +375,8 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
				
			|||||||
		logger.Error("remove image failed: ", err)
 | 
							logger.Error("remove image failed: ", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := h.pool.Clients.Get(data.UserId)
 | 
				
			||||||
 | 
						_ = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,6 +17,7 @@ import (
 | 
				
			|||||||
	"chatplus/store"
 | 
						"chatplus/store"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"embed"
 | 
						"embed"
 | 
				
			||||||
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
@@ -25,8 +26,6 @@ import (
 | 
				
			|||||||
	"syscall"
 | 
						"syscall"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
						"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
				
			||||||
	"go.uber.org/fx"
 | 
						"go.uber.org/fx"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
@@ -168,6 +167,7 @@ func main() {
 | 
				
			|||||||
		fx.Invoke(func(pool *mj.ServicePool) {
 | 
							fx.Invoke(func(pool *mj.ServicePool) {
 | 
				
			||||||
			if pool.HasAvailableService() {
 | 
								if pool.HasAvailableService() {
 | 
				
			||||||
				pool.DownloadImages()
 | 
									pool.DownloadImages()
 | 
				
			||||||
 | 
									pool.CheckTaskNotify()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -234,6 +234,7 @@ func main() {
 | 
				
			|||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/mj/")
 | 
								group := s.Engine.Group("/api/mj/")
 | 
				
			||||||
 | 
								group.Any("client", h.Client)
 | 
				
			||||||
			group.POST("image", h.Image)
 | 
								group.POST("image", h.Image)
 | 
				
			||||||
			group.POST("upscale", h.Upscale)
 | 
								group.POST("upscale", h.Upscale)
 | 
				
			||||||
			group.POST("variation", h.Variation)
 | 
								group.POST("variation", h.Variation)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,7 +4,7 @@ import (
 | 
				
			|||||||
	"chatplus/core/types"
 | 
						"chatplus/core/types"
 | 
				
			||||||
	logger2 "chatplus/logger"
 | 
						logger2 "chatplus/logger"
 | 
				
			||||||
	"chatplus/utils"
 | 
						"chatplus/utils"
 | 
				
			||||||
	"github.com/bwmarrin/discordgo"
 | 
						discordgo "github.com/bg5t/mydiscordgo"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
@@ -17,35 +17,48 @@ import (
 | 
				
			|||||||
var logger = logger2.GetLogger()
 | 
					var logger = logger2.GetLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Bot struct {
 | 
					type Bot struct {
 | 
				
			||||||
	config  *types.MidJourneyConfig
 | 
						config  types.MidJourneyConfig
 | 
				
			||||||
	bot     *discordgo.Session
 | 
						bot     *discordgo.Session
 | 
				
			||||||
	name    string
 | 
						name    string
 | 
				
			||||||
	service *Service
 | 
						service *Service
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
 | 
					func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
 | 
				
			||||||
	discord, err := discordgo.New("Bot " + config.BotToken)
 | 
						bot, err := discordgo.New("Bot " + config.BotToken)
 | 
				
			||||||
	logger.Info(config.BotToken)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Error(err)
 | 
							logger.Error(err)
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if proxy != "" {
 | 
						// use CDN reverse proxy
 | 
				
			||||||
		proxy, _ := url.Parse(proxy)
 | 
						if config.UseCDN {
 | 
				
			||||||
		discord.Client = &http.Client{
 | 
							discordgo.SetEndpointDiscord(config.DiscordAPI)
 | 
				
			||||||
			Transport: &http.Transport{
 | 
							discordgo.SetEndpointCDN(config.DiscordCDN)
 | 
				
			||||||
 | 
							discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
 | 
				
			||||||
 | 
							bot.MjGateway = config.DiscordGateway + "/"
 | 
				
			||||||
 | 
						} else { // use proxy
 | 
				
			||||||
 | 
							discordgo.SetEndpointDiscord("https://discord.com")
 | 
				
			||||||
 | 
							discordgo.SetEndpointCDN("https://cdn.discordapp.com")
 | 
				
			||||||
 | 
							discordgo.SetEndpointStatus("https://discord.com/api/v2/")
 | 
				
			||||||
 | 
							bot.MjGateway = "wss://gateway.discord.gg"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if proxy != "" {
 | 
				
			||||||
 | 
								proxy, _ := url.Parse(proxy)
 | 
				
			||||||
 | 
								bot.Client = &http.Client{
 | 
				
			||||||
 | 
									Transport: &http.Transport{
 | 
				
			||||||
 | 
										Proxy: http.ProxyURL(proxy),
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								bot.Dialer = &websocket.Dialer{
 | 
				
			||||||
				Proxy: http.ProxyURL(proxy),
 | 
									Proxy: http.ProxyURL(proxy),
 | 
				
			||||||
			},
 | 
								}
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		discord.Dialer = &websocket.Dialer{
 | 
					 | 
				
			||||||
			Proxy: http.ProxyURL(proxy),
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &Bot{
 | 
						return &Bot{
 | 
				
			||||||
		config:  config,
 | 
							config:  config,
 | 
				
			||||||
		bot:     discord,
 | 
							bot:     bot,
 | 
				
			||||||
		name:    name,
 | 
							name:    name,
 | 
				
			||||||
		service: service,
 | 
							service: service,
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,24 +12,32 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type Client struct {
 | 
					type Client struct {
 | 
				
			||||||
	client *req.Client
 | 
						client *req.Client
 | 
				
			||||||
	config types.MidJourneyConfig
 | 
						Config types.MidJourneyConfig
 | 
				
			||||||
 | 
						apiURL string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
					func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
				
			||||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
						client := req.C().SetTimeout(10 * time.Second)
 | 
				
			||||||
 | 
						var apiURL string
 | 
				
			||||||
	// set proxy URL
 | 
						// set proxy URL
 | 
				
			||||||
	if proxy != "" {
 | 
						if config.UseCDN {
 | 
				
			||||||
		client.SetProxyURL(proxy)
 | 
							apiURL = config.DiscordAPI + "/api/v9/interactions"
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							apiURL = "https://discord.com/api/v9/interactions"
 | 
				
			||||||
 | 
							if proxy != "" {
 | 
				
			||||||
 | 
								client.SetProxyURL(proxy)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &Client{client: client, config: config}
 | 
					
 | 
				
			||||||
 | 
						return &Client{client: client, Config: config, apiURL: apiURL}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) Imagine(prompt string) error {
 | 
					func (c *Client) Imagine(prompt string) error {
 | 
				
			||||||
	interactionsReq := &InteractionsRequest{
 | 
						interactionsReq := &InteractionsRequest{
 | 
				
			||||||
		Type:          2,
 | 
							Type:          2,
 | 
				
			||||||
		ApplicationID: ApplicationID,
 | 
							ApplicationID: ApplicationID,
 | 
				
			||||||
		GuildID:       c.config.GuildId,
 | 
							GuildID:       c.Config.GuildId,
 | 
				
			||||||
		ChannelID:     c.config.ChanelId,
 | 
							ChannelID:     c.Config.ChanelId,
 | 
				
			||||||
		SessionID:     SessionID,
 | 
							SessionID:     SessionID,
 | 
				
			||||||
		Data: map[string]any{
 | 
							Data: map[string]any{
 | 
				
			||||||
			"version": "1166847114203123795",
 | 
								"version": "1166847114203123795",
 | 
				
			||||||
@@ -67,11 +75,10 @@ func (c *Client) Imagine(prompt string) error {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	url := "https://discord.com/api/v9/interactions"
 | 
						r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
				
			||||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
					 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
		SetBody(interactionsReq).
 | 
							SetBody(interactionsReq).
 | 
				
			||||||
		Post(url)
 | 
							Post(c.apiURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
		return fmt.Errorf("error with http request: %w%v", err, r.Err)
 | 
							return fmt.Errorf("error with http request: %w%v", err, r.Err)
 | 
				
			||||||
@@ -86,8 +93,8 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
 | 
				
			|||||||
	interactionsReq := &InteractionsRequest{
 | 
						interactionsReq := &InteractionsRequest{
 | 
				
			||||||
		Type:          3,
 | 
							Type:          3,
 | 
				
			||||||
		ApplicationID: ApplicationID,
 | 
							ApplicationID: ApplicationID,
 | 
				
			||||||
		GuildID:       c.config.GuildId,
 | 
							GuildID:       c.Config.GuildId,
 | 
				
			||||||
		ChannelID:     c.config.ChanelId,
 | 
							ChannelID:     c.Config.ChanelId,
 | 
				
			||||||
		MessageFlags:  &flags,
 | 
							MessageFlags:  &flags,
 | 
				
			||||||
		MessageID:     &messageId,
 | 
							MessageID:     &messageId,
 | 
				
			||||||
		SessionID:     SessionID,
 | 
							SessionID:     SessionID,
 | 
				
			||||||
@@ -98,13 +105,12 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
 | 
				
			|||||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
							Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	url := "https://discord.com/api/v9/interactions"
 | 
					 | 
				
			||||||
	var res InteractionsResult
 | 
						var res InteractionsResult
 | 
				
			||||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
						r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
		SetBody(interactionsReq).
 | 
							SetBody(interactionsReq).
 | 
				
			||||||
		SetErrorResult(&res).
 | 
							SetErrorResult(&res).
 | 
				
			||||||
		Post(url)
 | 
							Post(c.apiURL)
 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
							return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -118,8 +124,8 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
 | 
				
			|||||||
	interactionsReq := &InteractionsRequest{
 | 
						interactionsReq := &InteractionsRequest{
 | 
				
			||||||
		Type:          3,
 | 
							Type:          3,
 | 
				
			||||||
		ApplicationID: ApplicationID,
 | 
							ApplicationID: ApplicationID,
 | 
				
			||||||
		GuildID:       c.config.GuildId,
 | 
							GuildID:       c.Config.GuildId,
 | 
				
			||||||
		ChannelID:     c.config.ChanelId,
 | 
							ChannelID:     c.Config.ChanelId,
 | 
				
			||||||
		MessageFlags:  &flags,
 | 
							MessageFlags:  &flags,
 | 
				
			||||||
		MessageID:     &messageId,
 | 
							MessageID:     &messageId,
 | 
				
			||||||
		SessionID:     SessionID,
 | 
							SessionID:     SessionID,
 | 
				
			||||||
@@ -130,13 +136,12 @@ func (c *Client) Variation(index int, messageId string, hash string) error {
 | 
				
			|||||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
							Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	url := "https://discord.com/api/v9/interactions"
 | 
					 | 
				
			||||||
	var res InteractionsResult
 | 
						var res InteractionsResult
 | 
				
			||||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
						r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
		SetBody(interactionsReq).
 | 
							SetBody(interactionsReq).
 | 
				
			||||||
		SetErrorResult(&res).
 | 
							SetErrorResult(&res).
 | 
				
			||||||
		Post(url)
 | 
							Post(c.apiURL)
 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
							return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,9 +6,9 @@ import (
 | 
				
			|||||||
	"chatplus/store"
 | 
						"chatplus/store"
 | 
				
			||||||
	"chatplus/store/model"
 | 
						"chatplus/store/model"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -16,13 +16,16 @@ import (
 | 
				
			|||||||
type ServicePool struct {
 | 
					type ServicePool struct {
 | 
				
			||||||
	services        []*Service
 | 
						services        []*Service
 | 
				
			||||||
	taskQueue       *store.RedisQueue
 | 
						taskQueue       *store.RedisQueue
 | 
				
			||||||
 | 
						notifyQueue     *store.RedisQueue
 | 
				
			||||||
	db              *gorm.DB
 | 
						db              *gorm.DB
 | 
				
			||||||
	uploaderManager *oss.UploaderManager
 | 
						uploaderManager *oss.UploaderManager
 | 
				
			||||||
 | 
						Clients         *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
					func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
				
			||||||
	services := make([]*Service, 0)
 | 
						services := make([]*Service, 0)
 | 
				
			||||||
	queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
						taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
				
			||||||
 | 
						notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
				
			||||||
	// create mj client and service
 | 
						// create mj client and service
 | 
				
			||||||
	for k, config := range appConfig.MjConfigs {
 | 
						for k, config := range appConfig.MjConfigs {
 | 
				
			||||||
		if config.Enabled == false {
 | 
							if config.Enabled == false {
 | 
				
			||||||
@@ -33,9 +36,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		name := fmt.Sprintf("MjService-%d", k)
 | 
							name := fmt.Sprintf("MjService-%d", k)
 | 
				
			||||||
		// create mj service
 | 
							// create mj service
 | 
				
			||||||
		service := NewService(name, queue, 4, 600, db, client)
 | 
							service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
 | 
				
			||||||
		botName := fmt.Sprintf("MjBot-%d", k)
 | 
							botName := fmt.Sprintf("MjBot-%d", k)
 | 
				
			||||||
		bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
 | 
							bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -54,13 +57,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &ServicePool{
 | 
						return &ServicePool{
 | 
				
			||||||
		taskQueue:       queue,
 | 
							taskQueue:       taskQueue,
 | 
				
			||||||
 | 
							notifyQueue:     notifyQueue,
 | 
				
			||||||
		services:        services,
 | 
							services:        services,
 | 
				
			||||||
		uploaderManager: manager,
 | 
							uploaderManager: manager,
 | 
				
			||||||
		db:              db,
 | 
							db:              db,
 | 
				
			||||||
 | 
							Clients:         types.NewLMap[uint, *types.WsClient](),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ServicePool) CheckTaskNotify() {
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								var userId uint
 | 
				
			||||||
 | 
								err := p.notifyQueue.LPop(&userId)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								client := p.Clients.Get(userId)
 | 
				
			||||||
 | 
								err = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ServicePool) DownloadImages() {
 | 
					func (p *ServicePool) DownloadImages() {
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		var items []model.MidJourneyJob
 | 
							var items []model.MidJourneyJob
 | 
				
			||||||
@@ -71,15 +93,21 @@ func (p *ServicePool) DownloadImages() {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// download images
 | 
								// download images
 | 
				
			||||||
			for _, item := range items {
 | 
								for _, v := range items {
 | 
				
			||||||
				imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(item.OrgURL, true)
 | 
									imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					logger.Error("error with download image: ", err)
 | 
										logger.Error("error with download image: ", err)
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				item.ImgURL = imgURL
 | 
									v.ImgURL = imgURL
 | 
				
			||||||
				p.db.Updates(&item)
 | 
									p.db.Updates(&v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									client := p.Clients.Get(uint(v.UserId))
 | 
				
			||||||
 | 
									err = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			time.Sleep(time.Second * 5)
 | 
								time.Sleep(time.Second * 5)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,7 @@ type Service struct {
 | 
				
			|||||||
	name             string  // service name
 | 
						name             string  // service name
 | 
				
			||||||
	client           *Client // MJ client
 | 
						client           *Client // MJ client
 | 
				
			||||||
	taskQueue        *store.RedisQueue
 | 
						taskQueue        *store.RedisQueue
 | 
				
			||||||
 | 
						notifyQueue      *store.RedisQueue
 | 
				
			||||||
	db               *gorm.DB
 | 
						db               *gorm.DB
 | 
				
			||||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
						maxHandleTaskNum int32             // max task number current service can handle
 | 
				
			||||||
	handledTaskNum   int32             // already handled task number
 | 
						handledTaskNum   int32             // already handled task number
 | 
				
			||||||
@@ -22,11 +23,12 @@ type Service struct {
 | 
				
			|||||||
	taskTimeout      int64
 | 
						taskTimeout      int64
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
 | 
					func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
 | 
				
			||||||
	return &Service{
 | 
						return &Service{
 | 
				
			||||||
		name:             name,
 | 
							name:             name,
 | 
				
			||||||
		db:               db,
 | 
							db:               db,
 | 
				
			||||||
		taskQueue:        queue,
 | 
							taskQueue:        taskQueue,
 | 
				
			||||||
 | 
							notifyQueue:      notifyQueue,
 | 
				
			||||||
		client:           client,
 | 
							client:           client,
 | 
				
			||||||
		taskTimeout:      timeout,
 | 
							taskTimeout:      timeout,
 | 
				
			||||||
		maxHandleTaskNum: maxTaskNum,
 | 
							maxHandleTaskNum: maxTaskNum,
 | 
				
			||||||
@@ -53,9 +55,10 @@ func (s *Service) Run() {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// if it's reference message, check if it's this channel's  message
 | 
							// if it's reference message, check if it's this channel's  message
 | 
				
			||||||
		if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId {
 | 
							if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
 | 
				
			||||||
			s.taskQueue.RPush(task)
 | 
								s.taskQueue.RPush(task)
 | 
				
			||||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
 | 
								s.notifyQueue.RPush(task.UserId)
 | 
				
			||||||
			time.Sleep(time.Second)
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -77,6 +80,7 @@ func (s *Service) Run() {
 | 
				
			|||||||
			logger.Error("绘画任务执行失败:", err)
 | 
								logger.Error("绘画任务执行失败:", err)
 | 
				
			||||||
			// update the task progress
 | 
								// update the task progress
 | 
				
			||||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
 | 
								s.notifyQueue.RPush(task.UserId)
 | 
				
			||||||
			// restore img_call quota
 | 
								// restore img_call quota
 | 
				
			||||||
			s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
 | 
								s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
@@ -134,6 +138,10 @@ func (s *Service) Notify(data CBReq) {
 | 
				
			|||||||
	job.Prompt = data.Prompt
 | 
						job.Prompt = data.Prompt
 | 
				
			||||||
	job.Hash = data.Image.Hash
 | 
						job.Hash = data.Image.Hash
 | 
				
			||||||
	job.OrgURL = data.Image.URL
 | 
						job.OrgURL = data.Image.URL
 | 
				
			||||||
 | 
						if s.client.Config.UseCDN {
 | 
				
			||||||
 | 
							job.UseProxy = true
 | 
				
			||||||
 | 
							job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	res = s.db.Updates(&job)
 | 
						res = s.db.Updates(&job)
 | 
				
			||||||
	if res.Error != nil {
 | 
						if res.Error != nil {
 | 
				
			||||||
@@ -146,4 +154,6 @@ func (s *Service) Notify(data CBReq) {
 | 
				
			|||||||
		atomic.AddInt32(&s.handledTaskNum, -1)
 | 
							atomic.AddInt32(&s.handledTaskNum, -1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.notifyQueue.RPush(job.UserId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ package model
 | 
				
			|||||||
type Function struct {
 | 
					type Function struct {
 | 
				
			||||||
	Id          uint `gorm:"primarykey;column:id"`
 | 
						Id          uint `gorm:"primarykey;column:id"`
 | 
				
			||||||
	Name        string
 | 
						Name        string
 | 
				
			||||||
 | 
						Label       string
 | 
				
			||||||
	Description string
 | 
						Description string
 | 
				
			||||||
	Parameters  string
 | 
						Parameters  string
 | 
				
			||||||
	Required    string
 | 
						Required    string
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,7 @@ type MidJourneyJob struct {
 | 
				
			|||||||
	Hash        string // message hash
 | 
						Hash        string // message hash
 | 
				
			||||||
	Progress    int
 | 
						Progress    int
 | 
				
			||||||
	Prompt      string
 | 
						Prompt      string
 | 
				
			||||||
 | 
						UseProxy    bool // 是否使用反代加载图片
 | 
				
			||||||
	CreatedAt   time.Time
 | 
						CreatedAt   time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,5 +15,6 @@ type MidJourneyJob struct {
 | 
				
			|||||||
	Hash        string    `json:"hash"`
 | 
						Hash        string    `json:"hash"`
 | 
				
			||||||
	Progress    int       `json:"progress"`
 | 
						Progress    int       `json:"progress"`
 | 
				
			||||||
	Prompt      string    `json:"prompt"`
 | 
						Prompt      string    `json:"prompt"`
 | 
				
			||||||
 | 
						UseProxy    bool      `json:"use_proxy"`
 | 
				
			||||||
	CreatedAt   time.Time `json:"created_at"`
 | 
						CreatedAt   time.Time `json:"created_at"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,4 +18,7 @@ ALTER TABLE `chatgpt_functions` ADD UNIQUE(`name`);
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
ALTER TABLE `chatgpt_functions` ADD `enabled` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否启用' AFTER `action`;
 | 
					ALTER TABLE `chatgpt_functions` ADD `enabled` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否启用' AFTER `action`;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALTER TABLE `chatgpt_functions` ADD `lebal` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`;
 | 
					ALTER TABLE `chatgpt_functions` ADD `label` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` ADD `use_proxy` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否使用反代' AFTER `progress`;
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL';
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								web/public/images/mj/mj-v6.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								web/public/images/mj/mj-v6.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 18 KiB  | 
@@ -359,7 +359,7 @@
 | 
				
			|||||||
              <template #default="scope">
 | 
					              <template #default="scope">
 | 
				
			||||||
                <div class="job-item">
 | 
					                <div class="job-item">
 | 
				
			||||||
                  <el-image
 | 
					                  <el-image
 | 
				
			||||||
                      :src="scope.item.type === 'upscale' ? scope.item['img_url'] + '?imageView2/1/w/480/h/600/q/75' : scope.item['img_url'] + '?imageView2/1/w/480/h/480/q/75'"
 | 
					                      :src="scope.item['thumb_url']"
 | 
				
			||||||
                      :class="scope.item.type === 'upscale' ? 'upscale' : ''" :zoom-rate="1.2"
 | 
					                      :class="scope.item.type === 'upscale' ? 'upscale' : ''" :zoom-rate="1.2"
 | 
				
			||||||
                      :preview-src-list="[scope.item['img_url']]" fit="cover" :initial-index="scope.index"
 | 
					                      :preview-src-list="[scope.item['img_url']]" fit="cover" :initial-index="scope.index"
 | 
				
			||||||
                      loading="lazy" v-if="scope.item.progress > 0">
 | 
					                      loading="lazy" v-if="scope.item.progress > 0">
 | 
				
			||||||
@@ -477,7 +477,8 @@ const rates = [
 | 
				
			|||||||
  {css: "size9-16", value: "9:16", text: "9:16", img: "/images/mj/rate_9_16.png"},
 | 
					  {css: "size9-16", value: "9:16", text: "9:16", img: "/images/mj/rate_9_16.png"},
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
const models = [
 | 
					const models = [
 | 
				
			||||||
  {text: "最新模式MJ-5.2", value: " --v 5.2", img: "/images/mj/mj-v5.2.png"},
 | 
					  {text: "写实模式MJ-6.0", value: " --v 6", img: "/images/mj/mj-v6.png"},
 | 
				
			||||||
 | 
					  {text: "优质模式MJ-5.2", value: " --v 5.2", img: "/images/mj/mj-v5.2.png"},
 | 
				
			||||||
  {text: "优质模式MJ-5.1", value: " --v 5.1", img: "/images/mj/mj-v5.1.jpg"},
 | 
					  {text: "优质模式MJ-5.1", value: " --v 5.1", img: "/images/mj/mj-v5.1.jpg"},
 | 
				
			||||||
  {text: "虚幻模式MJ-5", value: " --v 5", img: "/images/mj/mj-v5.jpg"},
 | 
					  {text: "虚幻模式MJ-5", value: " --v 5", img: "/images/mj/mj-v5.jpg"},
 | 
				
			||||||
  {text: "真实模式MJ-4", value: " --v 4", img: "/images/mj/mj-v4.jpg"},
 | 
					  {text: "真实模式MJ-4", value: " --v 4", img: "/images/mj/mj-v4.jpg"},
 | 
				
			||||||
@@ -489,6 +490,10 @@ const models = [
 | 
				
			|||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const options = [
 | 
					const options = [
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    value: 0,
 | 
				
			||||||
 | 
					    label: '默认'
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
    value: 0.25,
 | 
					    value: 0.25,
 | 
				
			||||||
    label: '普通'
 | 
					    label: '普通'
 | 
				
			||||||
@@ -515,7 +520,7 @@ const params = ref({
 | 
				
			|||||||
  prompt: "",
 | 
					  prompt: "",
 | 
				
			||||||
  neg_prompt: "",
 | 
					  neg_prompt: "",
 | 
				
			||||||
  tile: false,
 | 
					  tile: false,
 | 
				
			||||||
  quality: 0.5
 | 
					  quality: 0
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const activeName = ref('图生图')
 | 
					const activeName = ref('图生图')
 | 
				
			||||||
@@ -527,6 +532,7 @@ const router = useRouter()
 | 
				
			|||||||
const socket = ref(null)
 | 
					const socket = ref(null)
 | 
				
			||||||
const imgCalls = ref(0)
 | 
					const imgCalls = ref(0)
 | 
				
			||||||
const loading = ref(false)
 | 
					const loading = ref(false)
 | 
				
			||||||
 | 
					const userId = ref(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const rewritePrompt = () => {
 | 
					const rewritePrompt = () => {
 | 
				
			||||||
  loading.value = true
 | 
					  loading.value = true
 | 
				
			||||||
@@ -550,12 +556,40 @@ const translatePrompt = () => {
 | 
				
			|||||||
  })
 | 
					  })
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const connect = () => {
 | 
				
			||||||
 | 
					  let host = process.env.VUE_APP_WS_HOST
 | 
				
			||||||
 | 
					  if (host === '') {
 | 
				
			||||||
 | 
					    if (location.protocol === 'https:') {
 | 
				
			||||||
 | 
					      host = 'wss://' + location.host;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      host = 'ws://' + location.host;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const _socket = new WebSocket(host + `/api/mj/client?user_id=${userId.value}`);
 | 
				
			||||||
 | 
					  _socket.addEventListener('open', () => {
 | 
				
			||||||
 | 
					    socket.value = _socket;
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  _socket.addEventListener('message', event => {
 | 
				
			||||||
 | 
					    if (event.data instanceof Blob) {
 | 
				
			||||||
 | 
					      fetchRunningJobs(userId.value)
 | 
				
			||||||
 | 
					      fetchFinishJobs(userId.value)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  _socket.addEventListener('close', () => {
 | 
				
			||||||
 | 
					    connect()
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onMounted(() => {
 | 
					onMounted(() => {
 | 
				
			||||||
  checkSession().then(user => {
 | 
					  checkSession().then(user => {
 | 
				
			||||||
    imgCalls.value = user['img_calls']
 | 
					    imgCalls.value = user['img_calls']
 | 
				
			||||||
 | 
					    userId.value = user.id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fetchRunningJobs(user.id)
 | 
					    fetchRunningJobs(userId.value)
 | 
				
			||||||
    fetchFinishJobs(user.id)
 | 
					    fetchFinishJobs(userId.value)
 | 
				
			||||||
 | 
					    connect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  }).catch(() => {
 | 
					  }).catch(() => {
 | 
				
			||||||
    router.push('/login')
 | 
					    router.push('/login')
 | 
				
			||||||
@@ -588,39 +622,29 @@ const fetchRunningJobs = (userId) => {
 | 
				
			|||||||
      _jobs.push(jobs[i])
 | 
					      _jobs.push(jobs[i])
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    runningJobs.value = _jobs
 | 
					    runningJobs.value = _jobs
 | 
				
			||||||
 | 
					 | 
				
			||||||
    setTimeout(() => fetchRunningJobs(userId), 1000)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
					    ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
    setTimeout(() => fetchRunningJobs(userId), 5000)
 | 
					 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const fetchFinishJobs = (userId) => {
 | 
					const fetchFinishJobs = (userId) => {
 | 
				
			||||||
  // 获取已完成的任务
 | 
					  // 获取已完成的任务
 | 
				
			||||||
  httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
 | 
					  httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
 | 
				
			||||||
    if (finishedJobs.value.length === 0) {
 | 
					    const jobs = res.data
 | 
				
			||||||
      finishedJobs.value = res.data
 | 
					    for (let i = 0; i < jobs.length; i++) {
 | 
				
			||||||
      return
 | 
					      if (jobs[i]['use_proxy']) {
 | 
				
			||||||
    }
 | 
					        jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?x-oss-process=image/quality,q_60&format=webp'
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
    // check if the img url is changed
 | 
					        if (jobs[i].type === 'upscale') {
 | 
				
			||||||
    const list = res.data
 | 
					          jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?imageView2/1/w/480/h/600/q/75'
 | 
				
			||||||
    let changed = false
 | 
					        } else {
 | 
				
			||||||
    for (let i = 0; i < list.length; i++) {
 | 
					          jobs[i]['thumb_url'] = jobs[i]['img_url'] + '?imageView2/1/w/480/h/480/q/75'
 | 
				
			||||||
      if (list[i]["img_url"] !== finishedJobs.value[i]["img_url"]) {
 | 
					        }
 | 
				
			||||||
        changed = true
 | 
					 | 
				
			||||||
        break
 | 
					 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (changed) {
 | 
					    finishedJobs.value = jobs
 | 
				
			||||||
      finishedJobs.value = list
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    setTimeout(() => fetchFinishJobs(userId), 1000)
 | 
					 | 
				
			||||||
  }).catch(e => {
 | 
					  }).catch(e => {
 | 
				
			||||||
    ElMessage.error("获取任务失败:" + e.message)
 | 
					    ElMessage.error("获取任务失败:" + e.message)
 | 
				
			||||||
    setTimeout(() => fetchFinishJobs(userId), 5000)
 | 
					 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -710,7 +734,7 @@ const removeImage = (item) => {
 | 
				
			|||||||
        type: 'warning',
 | 
					        type: 'warning',
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
  ).then(() => {
 | 
					  ).then(() => {
 | 
				
			||||||
    httpPost("/api/mj/remove", {id: item.id, img_url: item.img_url}).then(() => {
 | 
					    httpPost("/api/mj/remove", {id: item.id, img_url: item.img_url, user_id: userId.value}).then(() => {
 | 
				
			||||||
      ElMessage.success("任务删除成功")
 | 
					      ElMessage.success("任务删除成功")
 | 
				
			||||||
    }).catch(e => {
 | 
					    }).catch(e => {
 | 
				
			||||||
      ElMessage.error("任务删除失败:" + e.message)
 | 
					      ElMessage.error("任务删除失败:" + e.message)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -612,8 +612,9 @@ onMounted(() => {
 | 
				
			|||||||
  // 获取已完成的任务
 | 
					  // 获取已完成的任务
 | 
				
			||||||
  const fetchFinishJobs = (userId) => {
 | 
					  const fetchFinishJobs = (userId) => {
 | 
				
			||||||
    httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
 | 
					    httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
 | 
				
			||||||
      if (finishedJobs.value.length === 0) {
 | 
					      if (finishedJobs.value.length === 0 || res.data.length > finishedJobs.value.length) {
 | 
				
			||||||
        finishedJobs.value = res.data
 | 
					        finishedJobs.value = res.data
 | 
				
			||||||
 | 
					        setTimeout(() => fetchFinishJobs(userId), 1000)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user