diff --git a/api/core/app_server.go b/api/core/app_server.go index a29c32d9..94837d0c 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -154,6 +154,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/mj/jobs" || + c.Request.URL.Path == "/api/mj/client" || c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/jobs" || c.Request.URL.Path == "/api/upload" || diff --git a/api/core/config.go b/api/core/config.go index 5e02c821..bf8a1e63 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -14,13 +14,12 @@ var logger = logger2.GetLogger() func NewDefaultConfig() *types.AppConfig { return &types.AppConfig{ - Listen: "0.0.0.0:5678", - ProxyURL: "", - Manager: types.Manager{Username: "admin", Password: "admin123"}, - StaticDir: "./static", - StaticUrl: "http://localhost/5678/static", - Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, - AesEncryptKey: utils.RandString(24), + Listen: "0.0.0.0:5678", + ProxyURL: "", + Manager: types.Manager{Username: "admin", Password: "admin123"}, + StaticDir: "./static", + StaticUrl: "http://localhost/5678/static", + Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, Session: types.Session{ SecretKey: utils.RandString(64), MaxAge: 86400, diff --git a/api/core/types/config.go b/api/core/types/config.go index 99dc4838..2cc6d738 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -5,22 +5,21 @@ import ( ) type AppConfig struct { - Path string `toml:"-"` - Listen string - Session Session - ProxyURL string - MysqlDns string // mysql 连接地址 - Manager Manager // 后台管理员账户信息 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs - AesEncryptKey string - SmsConfig AliYunSmsConfig // AliYun send message service config - OSS OSSConfig // OSS config - MjConfigs []MidJourneyConfig // mj AI draw service pool - WeChatBot bool // 是否启用微信机器人 - SdConfigs []StableDiffusionConfig // sd AI draw service pool + Path string `toml:"-"` + Listen string + Session Session + ProxyURL string + MysqlDns string // mysql 连接地址 + Manager Manager // 后台管理员账户信息 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs + SmsConfig AliYunSmsConfig // AliYun send message service config + OSS OSSConfig // OSS config + MjConfigs []MidJourneyConfig // mj AI draw service pool + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool XXLConfig XXLConfig AlipayConfig AlipayConfig @@ -34,11 +33,15 @@ type ChatPlusApiConfig struct { } type MidJourneyConfig struct { - Enabled bool - UserToken string - BotToken string - GuildId string // Server ID - ChanelId string // Chanel ID + Enabled bool + UserToken string + BotToken string + GuildId string // Server ID + ChanelId string // Chanel ID + UseCDN bool + DiscordAPI string + DiscordCDN string + DiscordGateway string } type StableDiffusionConfig struct { diff --git a/api/core/types/function.go b/api/core/types/function.go index 62237133..048c54df 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -32,46 +32,21 @@ const ( var InnerFunctions = []Function{ { Name: FuncZaoBao, - Description: "每日早报,获取当天全球的热门新闻事件列表", + Description: "每日早报,获取当天新闻事件列表", Parameters: Parameters{ - Type: "object", - Properties: map[string]Property{ - "text": { - Type: "string", - Description: "", - }, - }, - Required: []string{}, + Type: "object", + Properties: map[string]Property{}, + Required: []string{}, }, }, { Name: FuncWeibo, Description: "新浪微博热搜榜,微博当日热搜榜单", Parameters: Parameters{ - Type: "object", - Properties: map[string]Property{ - "text": { - Type: "string", - Description: "", - }, - }, - Required: []string{}, - }, - }, - - { - Name: FuncHeadLine, - Description: "今日头条,给用户推荐当天的头条新闻,周榜热文", - Parameters: Parameters{ - Type: "object", - Properties: map[string]Property{ - "text": { - Type: "string", - Description: "", - }, - }, - Required: []string{}, + Type: "object", + Properties: map[string]Property{}, + Required: []string{}, }, }, diff --git a/api/core/types/locked_map.go b/api/core/types/locked_map.go index 13915c43..ede72f34 100644 --- a/api/core/types/locked_map.go +++ b/api/core/types/locked_map.go @@ -6,7 +6,7 @@ import ( ) type MKey interface { - string | int + string | int | uint } type MValue interface { *WsClient | *ChatSession | context.CancelFunc | []interface{} diff --git a/api/go.mod b/api/go.mod index a95a236b..98bbc104 100644 --- a/api/go.mod +++ b/api/go.mod @@ -6,7 +6,6 @@ require ( github.com/BurntSushi/toml v1.1.0 github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible - github.com/bwmarrin/discordgo v0.27.1 github.com/eatmoreapple/openwechat v1.2.1 github.com/gin-gonic/gin v1.9.1 github.com/go-redis/redis/v8 v8.11.5 @@ -26,6 +25,8 @@ require ( require github.com/xxl-job/xxl-job-executor-go v1.2.0 +require github.com/bg5t/mydiscordgo v0.28.1 + require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/bytedance/sonic v1.9.1 // indirect diff --git a/api/go.sum b/api/go.sum index b14415c9..e677b543 100644 --- a/api/go.sum +++ b/api/go.sum @@ -7,8 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9 github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= -github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY= -github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/bg5t/mydiscordgo v0.28.1 h1:mVH0ZWstVdJffCi/EXJAYQDtXwIKAJYVXLmECu1hEK8= +github.com/bg5t/mydiscordgo v0.28.1/go.mod h1:n3aba73N18k1DzM0t0mGE8rwW3Z+vwTvI8pcsBgxN/8= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= diff --git a/api/handler/admin/function_handler.go b/api/handler/admin/function_handler.go index 8eff4cbd..26a8805f 100644 --- a/api/handler/admin/function_handler.go +++ b/api/handler/admin/function_handler.go @@ -6,7 +6,11 @@ import ( "chatplus/handler" "chatplus/store/model" "chatplus/store/vo" + "chatplus/utils" "chatplus/utils/resp" + + "github.com/golang-jwt/jwt/v5" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -29,13 +33,65 @@ func (h *FunctionHandler) Save(c *gin.Context) { return } - logger.Info(data) + var f = model.Function{ + Id: data.Id, + Name: data.Name, + Label: data.Label, + Description: data.Description, + Parameters: utils.JsonEncode(data.Parameters), + Required: utils.JsonEncode(data.Required), + Action: data.Action, + Token: data.Token, + Enabled: data.Enabled, + } + + res := h.db.Save(&f) + if res.Error != nil { + resp.ERROR(c, "error with save data:"+res.Error.Error()) + return + } + data.Id = f.Id + resp.SUCCESS(c, data) +} + +func (h *FunctionHandler) Set(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Filed string `json:"filed"` + Value interface{} `json:"value"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.db.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) + if res.Error != nil { + resp.ERROR(c, "更新数据库失败!") + return + } resp.SUCCESS(c) } func (h *FunctionHandler) List(c *gin.Context) { + var items []model.Function + res := h.db.Find(&items) + if res.Error != nil { + resp.ERROR(c, "No data found") + return + } - resp.SUCCESS(c) + functions := make([]vo.Function, 0) + for _, v := range items { + var f vo.Function + err := utils.CopyObject(v, &f) + if err != nil { + continue + } + functions = append(functions, f) + } + resp.SUCCESS(c, functions) } func (h *FunctionHandler) Remove(c *gin.Context) { @@ -50,3 +106,20 @@ func (h *FunctionHandler) Remove(c *gin.Context) { } resp.SUCCESS(c) } + +// GenToken generate function api access token +func (h *FunctionHandler) GenToken(c *gin.Context) { + // 创建 token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": 0, + "expired": 0, + }) + tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) + if err != nil { + logger.Error("error with generate token", err) + resp.ERROR(c) + return + } + + resp.SUCCESS(c, tokenString) +} diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index f81df2a9..ce088b99 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -442,7 +442,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) + logger.Infof("Sending %s request, ApiURL:%s, PROXY: %s, Model: %s", platform, apiURL, proxyURL, req.Model) switch platform { case types.Azure: request.Header.Set("api-key", *apiKey) @@ -452,7 +452,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf if err != nil { return nil, err } - logger.Info(token) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) break case types.Baidu: diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index be39a84b..d4ccd664 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -50,7 +50,7 @@ type xunFeiResp struct { } var Model2URL = map[string]string{ - "generalv1": "1.1", + "general": "v1.1", "generalv2": "v2.1", "generalv3": "v3.1", } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 7b39aeed..ad946e14 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -13,7 +13,9 @@ import ( "encoding/base64" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "gorm.io/gorm" + "net/http" "strings" "time" ) @@ -58,6 +60,27 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { } +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *MidJourneyHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.pool.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { @@ -147,6 +170,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { UserId: userId, }) + client := h.pool.Clients.Get(uint(job.UserId)) + _ = client.Send([]byte("Task Updated")) + // update user's img calls h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) resp.SUCCESS(c) @@ -205,6 +231,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { MessageId: data.MessageId, MessageHash: data.MessageHash, }) + + client := h.pool.Clients.Get(uint(job.UserId)) + _ = client.Send([]byte("Task Updated")) + resp.SUCCESS(c) } @@ -226,6 +256,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { job := model.MidJourneyJob{ Type: types.TaskVariation.String(), + ChannelId: data.ChannelId, ReferenceId: data.MessageId, UserId: userId, TaskId: data.TaskId, @@ -250,6 +281,9 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { MessageHash: data.MessageHash, }) + client := h.pool.Clients.Get(uint(job.UserId)) + _ = client.Send([]byte("Task Updated")) + // update user's img calls h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) resp.SUCCESS(c) @@ -320,6 +354,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { func (h *MidJourneyHandler) Remove(c *gin.Context) { var data struct { Id uint `json:"id"` + UserId uint `json:"user_id"` ImgURL string `json:"img_url"` } if err := c.ShouldBindJSON(&data); err != nil { @@ -340,5 +375,8 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } + client := h.pool.Clients.Get(data.UserId) + _ = client.Send([]byte("Task Updated")) + resp.SUCCESS(c) } diff --git a/api/main.go b/api/main.go index 9d76db9c..721f461d 100644 --- a/api/main.go +++ b/api/main.go @@ -168,6 +168,7 @@ func main() { fx.Invoke(func(pool *mj.ServicePool) { if pool.HasAvailableService() { pool.DownloadImages() + pool.CheckTaskNotify() } }), @@ -234,6 +235,7 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { group := s.Engine.Group("/api/mj/") + group.Any("client", h.Client) group.POST("image", h.Image) group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) @@ -350,8 +352,10 @@ func main() { fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { group := s.Engine.Group("/api/admin/function/") group.POST("save", h.Save) + group.POST("set", h.Set) group.GET("list", h.List) group.GET("remove", h.Remove) + group.GET("token", h.GenToken) }), fx.Provide(handler.NewTestHandler), diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go index 2d78b65c..912edf29 100644 --- a/api/service/mj/bot.go +++ b/api/service/mj/bot.go @@ -4,7 +4,7 @@ import ( "chatplus/core/types" logger2 "chatplus/logger" "chatplus/utils" - "github.com/bwmarrin/discordgo" + discordgo "github.com/bg5t/mydiscordgo" "github.com/gorilla/websocket" "net/http" "net/url" @@ -17,33 +17,48 @@ import ( var logger = logger2.GetLogger() type Bot struct { - config *types.MidJourneyConfig + config types.MidJourneyConfig bot *discordgo.Session name string service *Service } -func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) { - discord, err := discordgo.New("Bot " + config.BotToken) +func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) { + bot, err := discordgo.New("Bot " + config.BotToken) if err != nil { + logger.Error(err) return nil, err } - if proxy != "" { - proxy, _ := url.Parse(proxy) - discord.Client = &http.Client{ - Transport: &http.Transport{ + // use CDN reverse proxy + if config.UseCDN { + discordgo.SetEndpointDiscord(config.DiscordAPI) + discordgo.SetEndpointCDN(config.DiscordCDN) + discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/") + bot.MjGateway = config.DiscordGateway + "/" + } else { // use proxy + discordgo.SetEndpointDiscord("https://discord.com") + discordgo.SetEndpointCDN("https://cdn.discordapp.com") + discordgo.SetEndpointStatus("https://discord.com/api/v2/") + bot.MjGateway = "wss://gateway.discord.gg" + + if proxy != "" { + proxy, _ := url.Parse(proxy) + bot.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + bot.Dialer = &websocket.Dialer{ Proxy: http.ProxyURL(proxy), - }, - } - discord.Dialer = &websocket.Dialer{ - Proxy: http.ProxyURL(proxy), + } } + } return &Bot{ config: config, - bot: discord, + bot: bot, name: name, service: service, }, nil diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 7eb15bb4..eb84240f 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -12,25 +12,32 @@ import ( type Client struct { client *req.Client - config types.MidJourneyConfig + Config types.MidJourneyConfig + apiURL string } func NewClient(config types.MidJourneyConfig, proxy string) *Client { client := req.C().SetTimeout(10 * time.Second) + var apiURL string // set proxy URL - if proxy != "" { - client.SetProxyURL(proxy) + if config.UseCDN { + apiURL = config.DiscordAPI + "/api/v9/interactions" + } else { + apiURL = "https://discord.com/api/v9/interactions" + if proxy != "" { + client.SetProxyURL(proxy) + } } - logger.Info(config) - return &Client{client: client, config: config} + + return &Client{client: client, Config: config, apiURL: apiURL} } func (c *Client) Imagine(prompt string) error { interactionsReq := &InteractionsRequest{ Type: 2, ApplicationID: ApplicationID, - GuildID: c.config.GuildId, - ChannelID: c.config.ChanelId, + GuildID: c.Config.GuildId, + ChannelID: c.Config.ChanelId, SessionID: SessionID, Data: map[string]any{ "version": "1166847114203123795", @@ -68,11 +75,10 @@ func (c *Client) Imagine(prompt string) error { }, } - url := "https://discord.com/api/v9/interactions" - r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). SetHeader("Content-Type", "application/json"). SetBody(interactionsReq). - Post(url) + Post(c.apiURL) if err != nil || r.IsErrorState() { return fmt.Errorf("error with http request: %w%v", err, r.Err) @@ -87,8 +93,8 @@ func (c *Client) Upscale(index int, messageId string, hash string) error { interactionsReq := &InteractionsRequest{ Type: 3, ApplicationID: ApplicationID, - GuildID: c.config.GuildId, - ChannelID: c.config.ChanelId, + GuildID: c.Config.GuildId, + ChannelID: c.Config.ChanelId, MessageFlags: &flags, MessageID: &messageId, SessionID: SessionID, @@ -99,13 +105,12 @@ func (c *Client) Upscale(index int, messageId string, hash string) error { Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), } - url := "https://discord.com/api/v9/interactions" var res InteractionsResult - r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). SetHeader("Content-Type", "application/json"). SetBody(interactionsReq). SetErrorResult(&res). - Post(url) + Post(c.apiURL) if err != nil || r.IsErrorState() { return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) } @@ -119,8 +124,8 @@ func (c *Client) Variation(index int, messageId string, hash string) error { interactionsReq := &InteractionsRequest{ Type: 3, ApplicationID: ApplicationID, - GuildID: c.config.GuildId, - ChannelID: c.config.ChanelId, + GuildID: c.Config.GuildId, + ChannelID: c.Config.ChanelId, MessageFlags: &flags, MessageID: &messageId, SessionID: SessionID, @@ -131,13 +136,12 @@ func (c *Client) Variation(index int, messageId string, hash string) error { Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), } - url := "https://discord.com/api/v9/interactions" var res InteractionsResult - r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). SetHeader("Content-Type", "application/json"). SetBody(interactionsReq). SetErrorResult(&res). - Post(url) + Post(c.apiURL) if err != nil || r.IsErrorState() { return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) } diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 6fd5fd48..b446ad23 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -6,9 +6,9 @@ import ( "chatplus/store" "chatplus/store/model" "fmt" + "github.com/go-redis/redis/v8" "time" - "github.com/go-redis/redis/v8" "gorm.io/gorm" ) @@ -16,13 +16,16 @@ import ( type ServicePool struct { services []*Service taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue db *gorm.DB uploaderManager *oss.UploaderManager + Clients *types.LMap[uint, *types.WsClient] // UserId => Client } func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { services := make([]*Service, 0) - queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) + taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) + notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) // create mj client and service for k, config := range appConfig.MjConfigs { if config.Enabled == false { @@ -33,9 +36,9 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa name := fmt.Sprintf("MjService-%d", k) // create mj service - service := NewService(name, queue, 4, 600, db, client) + service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client) botName := fmt.Sprintf("MjBot-%d", k) - bot, err := NewBot(botName, appConfig.ProxyURL, &config, service) + bot, err := NewBot(botName, appConfig.ProxyURL, config, service) if err != nil { continue } @@ -54,13 +57,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa } return &ServicePool{ - taskQueue: queue, + taskQueue: taskQueue, + notifyQueue: notifyQueue, services: services, uploaderManager: manager, db: db, + Clients: types.NewLMap[uint, *types.WsClient](), } } +func (p *ServicePool) CheckTaskNotify() { + go func() { + for { + var userId uint + err := p.notifyQueue.LPop(&userId) + if err != nil { + continue + } + client := p.Clients.Get(userId) + err = client.Send([]byte("Task Updated")) + if err != nil { + continue + } + } + }() +} + func (p *ServicePool) DownloadImages() { go func() { var items []model.MidJourneyJob @@ -71,15 +93,21 @@ func (p *ServicePool) DownloadImages() { } // download images - for _, item := range items { - imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(item.OrgURL, true) + for _, v := range items { + imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true) if err != nil { logger.Error("error with download image: ", err) continue } - item.ImgURL = imgURL - p.db.Updates(&item) + v.ImgURL = imgURL + p.db.Updates(&v) + + client := p.Clients.Get(uint(v.UserId)) + err = client.Send([]byte("Task Updated")) + if err != nil { + continue + } } time.Sleep(time.Second * 5) diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 3b3d6e7c..04cd94d8 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -15,6 +15,7 @@ type Service struct { name string // service name client *Client // MJ client taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue db *gorm.DB maxHandleTaskNum int32 // max task number current service can handle handledTaskNum int32 // already handled task number @@ -22,11 +23,12 @@ type Service struct { taskTimeout int64 } -func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { +func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { return &Service{ name: name, db: db, - taskQueue: queue, + taskQueue: taskQueue, + notifyQueue: notifyQueue, client: client, taskTimeout: timeout, maxHandleTaskNum: maxTaskNum, @@ -53,9 +55,10 @@ func (s *Service) Run() { } // if it's reference message, check if it's this channel's message - if task.ChannelId != "" && task.ChannelId != s.client.config.ChanelId { + if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId { s.taskQueue.RPush(task) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + s.notifyQueue.RPush(task.UserId) time.Sleep(time.Second) continue } @@ -77,6 +80,7 @@ func (s *Service) Run() { logger.Error("绘画任务执行失败:", err) // update the task progress s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + s.notifyQueue.RPush(task.UserId) // restore img_call quota s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) continue @@ -134,6 +138,10 @@ func (s *Service) Notify(data CBReq) { job.Prompt = data.Prompt job.Hash = data.Image.Hash job.OrgURL = data.Image.URL + if s.client.Config.UseCDN { + job.UseProxy = true + job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN) + } res = s.db.Updates(&job) if res.Error != nil { @@ -146,4 +154,6 @@ func (s *Service) Notify(data CBReq) { atomic.AddInt32(&s.handledTaskNum, -1) } + s.notifyQueue.RPush(job.UserId) + } diff --git a/api/store/model/function.go b/api/store/model/function.go index b737d18c..098750c5 100644 --- a/api/store/model/function.go +++ b/api/store/model/function.go @@ -3,9 +3,11 @@ package model type Function struct { Id uint `gorm:"primarykey;column:id"` Name string + Label string Description string Parameters string Required string Action string + Token string Enabled bool } diff --git a/api/store/model/mj_job.go b/api/store/model/mj_job.go index 488c3b20..bb949060 100644 --- a/api/store/model/mj_job.go +++ b/api/store/model/mj_job.go @@ -15,6 +15,7 @@ type MidJourneyJob struct { Hash string // message hash Progress int Prompt string + UseProxy bool // 是否使用反代加载图片 CreatedAt time.Time } diff --git a/api/store/vo/function.go b/api/store/vo/function.go index 09f1f7bc..afa1d705 100644 --- a/api/store/vo/function.go +++ b/api/store/vo/function.go @@ -14,9 +14,11 @@ type Property struct { type Function struct { Id uint `json:"id"` Name string `json:"name"` + Label string `json:"label"` Description string `json:"description"` Parameters Parameters `json:"parameters"` Required []string `json:"required"` Action string `json:"action"` + Token string `json:"token"` Enabled bool `json:"enabled"` } diff --git a/api/store/vo/mj_job.go b/api/store/vo/mj_job.go index bfc236ce..3ffcb376 100644 --- a/api/store/vo/mj_job.go +++ b/api/store/vo/mj_job.go @@ -15,5 +15,6 @@ type MidJourneyJob struct { Hash string `json:"hash"` Progress int `json:"progress"` Prompt string `json:"prompt"` + UseProxy bool `json:"use_proxy"` CreatedAt time.Time `json:"created_at"` } diff --git a/database/update-v3.2.3.sql b/database/update-v3.2.3.sql index e64c6ccc..4c5f917e 100644 --- a/database/update-v3.2.3.sql +++ b/database/update-v3.2.3.sql @@ -18,4 +18,9 @@ ALTER TABLE `chatgpt_functions` ADD UNIQUE(`name`); ALTER TABLE `chatgpt_functions` ADD `enabled` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否启用' AFTER `action`; -ALTER TABLE `chatgpt_functions` ADD `lebal` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`; \ No newline at end of file +ALTER TABLE `chatgpt_functions` ADD `label` VARCHAR(30) NULL COMMENT '函数标签' AFTER `name`; + +ALTER TABLE `chatgpt_mj_jobs` ADD `use_proxy` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '是否使用反代' AFTER `progress`; +ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL'; + +ALTER TABLE `chatgpt_functions` ADD `token` VARCHAR(255) NULL COMMENT 'API授权token' AFTER `action`; diff --git a/web/public/images/mj/mj-v6.png b/web/public/images/mj/mj-v6.png new file mode 100644 index 00000000..2bd6f30e Binary files /dev/null and b/web/public/images/mj/mj-v6.png differ diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index c5d0e960..d7993788 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -359,7 +359,7 @@