mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	feat: midjourney plus service is ready
This commit is contained in:
		@@ -152,6 +152,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
				
			|||||||
			c.Request.URL.Path == "/api/role/list" ||
 | 
								c.Request.URL.Path == "/api/role/list" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/mj/jobs" ||
 | 
								c.Request.URL.Path == "/api/mj/jobs" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/mj/client" ||
 | 
								c.Request.URL.Path == "/api/mj/client" ||
 | 
				
			||||||
 | 
								c.Request.URL.Path == "/api/mj/notify" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/invite/hits" ||
 | 
								c.Request.URL.Path == "/api/invite/hits" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/sd/jobs" ||
 | 
								c.Request.URL.Path == "/api/sd/jobs" ||
 | 
				
			||||||
			c.Request.URL.Path == "/api/upload" ||
 | 
								c.Request.URL.Path == "/api/upload" ||
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,21 +5,23 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AppConfig struct {
 | 
					type AppConfig struct {
 | 
				
			||||||
	Path      string `toml:"-"`
 | 
						Path          string `toml:"-"`
 | 
				
			||||||
	Listen    string
 | 
						Listen        string
 | 
				
			||||||
	Session   Session
 | 
						Session       Session
 | 
				
			||||||
	ProxyURL  string
 | 
						ProxyURL      string
 | 
				
			||||||
	MysqlDns  string                  // mysql 连接地址
 | 
						MysqlDns      string                  // mysql 连接地址
 | 
				
			||||||
	Manager   Manager                 // 后台管理员账户信息
 | 
						Manager       Manager                 // 后台管理员账户信息
 | 
				
			||||||
	StaticDir string                  // 静态资源目录
 | 
						StaticDir     string                  // 静态资源目录
 | 
				
			||||||
	StaticUrl string                  // 静态资源 URL
 | 
						StaticUrl     string                  // 静态资源 URL
 | 
				
			||||||
	Redis     RedisConfig             // redis 连接信息
 | 
						Redis         RedisConfig             // redis 连接信息
 | 
				
			||||||
	ApiConfig ChatPlusApiConfig       // ChatPlus API authorization configs
 | 
						ApiConfig     ChatPlusApiConfig       // ChatPlus API authorization configs
 | 
				
			||||||
	SmsConfig AliYunSmsConfig         // AliYun send message service config
 | 
						SmsConfig     AliYunSmsConfig         // AliYun send message service config
 | 
				
			||||||
	OSS       OSSConfig               // OSS config
 | 
						OSS           OSSConfig               // OSS config
 | 
				
			||||||
	MjConfigs []MidJourneyConfig      // mj AI draw service pool
 | 
						MjConfigs     []MidJourneyConfig      // mj AI draw service pool
 | 
				
			||||||
	WeChatBot bool                    // 是否启用微信机器人
 | 
						MjPlusConfigs []MidJourneyPlusConfig  // MJ plus config
 | 
				
			||||||
	SdConfigs []StableDiffusionConfig // sd AI draw service pool
 | 
						ImgCdnURL     string                  // 图片反代加速地址
 | 
				
			||||||
 | 
						WeChatBot     bool                    // 是否启用微信机器人
 | 
				
			||||||
 | 
						SdConfigs     []StableDiffusionConfig // sd AI draw service pool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	XXLConfig     XXLConfig
 | 
						XXLConfig     XXLConfig
 | 
				
			||||||
	AlipayConfig  AlipayConfig
 | 
						AlipayConfig  AlipayConfig
 | 
				
			||||||
@@ -60,7 +62,6 @@ type MidJourneyConfig struct {
 | 
				
			|||||||
	ChanelId       string // Chanel ID
 | 
						ChanelId       string // Chanel ID
 | 
				
			||||||
	UseCDN         bool
 | 
						UseCDN         bool
 | 
				
			||||||
	DiscordAPI     string
 | 
						DiscordAPI     string
 | 
				
			||||||
	DiscordCDN     string
 | 
					 | 
				
			||||||
	DiscordGateway string
 | 
						DiscordGateway string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -71,6 +72,14 @@ type StableDiffusionConfig struct {
 | 
				
			|||||||
	Txt2ImgJsonPath string
 | 
						Txt2ImgJsonPath string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MidJourneyPlusConfig struct {
 | 
				
			||||||
 | 
						Enabled   bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
				
			||||||
 | 
						Name      string // 服务名称,保持唯一
 | 
				
			||||||
 | 
						ApiURL    string
 | 
				
			||||||
 | 
						ApiKey    string
 | 
				
			||||||
 | 
						NotifyURL string // 任务进度更新回调地址
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AliYunSmsConfig struct {
 | 
					type AliYunSmsConfig struct {
 | 
				
			||||||
	AccessKey    string
 | 
						AccessKey    string
 | 
				
			||||||
	AccessSecret string
 | 
						AccessSecret string
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,6 +5,7 @@ import (
 | 
				
			|||||||
	"chatplus/core/types"
 | 
						"chatplus/core/types"
 | 
				
			||||||
	"chatplus/service"
 | 
						"chatplus/service"
 | 
				
			||||||
	"chatplus/service/mj"
 | 
						"chatplus/service/mj"
 | 
				
			||||||
 | 
						"chatplus/service/mj/plus"
 | 
				
			||||||
	"chatplus/service/oss"
 | 
						"chatplus/service/oss"
 | 
				
			||||||
	"chatplus/store/model"
 | 
						"chatplus/store/model"
 | 
				
			||||||
	"chatplus/store/vo"
 | 
						"chatplus/store/vo"
 | 
				
			||||||
@@ -203,7 +204,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
						idValue, _ := c.Get(types.LoginUserID)
 | 
				
			||||||
	jobId := 0
 | 
					 | 
				
			||||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
						userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
				
			||||||
	taskId, _ := h.snowflake.Next(true)
 | 
						taskId, _ := h.snowflake.Next(true)
 | 
				
			||||||
	job := model.MidJourneyJob{
 | 
						job := model.MidJourneyJob{
 | 
				
			||||||
@@ -221,7 +221,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h.pool.PushTask(types.MjTask{
 | 
						h.pool.PushTask(types.MjTask{
 | 
				
			||||||
		Id:          jobId,
 | 
							Id:          int(job.Id),
 | 
				
			||||||
		SessionId:   data.SessionId,
 | 
							SessionId:   data.SessionId,
 | 
				
			||||||
		Type:        types.TaskUpscale,
 | 
							Type:        types.TaskUpscale,
 | 
				
			||||||
		Prompt:      data.Prompt,
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
@@ -251,7 +251,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
						idValue, _ := c.Get(types.LoginUserID)
 | 
				
			||||||
	jobId := 0
 | 
					 | 
				
			||||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
						userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
				
			||||||
	taskId, _ := h.snowflake.Next(true)
 | 
						taskId, _ := h.snowflake.Next(true)
 | 
				
			||||||
	job := model.MidJourneyJob{
 | 
						job := model.MidJourneyJob{
 | 
				
			||||||
@@ -270,7 +269,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h.pool.PushTask(types.MjTask{
 | 
						h.pool.PushTask(types.MjTask{
 | 
				
			||||||
		Id:          jobId,
 | 
							Id:          int(job.Id),
 | 
				
			||||||
		SessionId:   data.SessionId,
 | 
							SessionId:   data.SessionId,
 | 
				
			||||||
		Type:        types.TaskVariation,
 | 
							Type:        types.TaskVariation,
 | 
				
			||||||
		Prompt:      data.Prompt,
 | 
							Prompt:      data.Prompt,
 | 
				
			||||||
@@ -340,9 +339,13 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// 正在运行中任务使用代理访问图片
 | 
								// 正在运行中任务使用代理访问图片
 | 
				
			||||||
			if item.ImgURL == "" && item.OrgURL != "" {
 | 
								if item.ImgURL == "" && item.OrgURL != "" {
 | 
				
			||||||
				image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
 | 
									if h.App.Config.ImgCdnURL != "" {
 | 
				
			||||||
				if err == nil {
 | 
										job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL)
 | 
				
			||||||
					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
									} else {
 | 
				
			||||||
 | 
										image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
 | 
				
			||||||
 | 
										if err == nil {
 | 
				
			||||||
 | 
											job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -382,3 +385,24 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Notify MidJourney Plus 服务任务回调处理
 | 
				
			||||||
 | 
					func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
				
			||||||
 | 
						var data plus.CBReq
 | 
				
			||||||
 | 
						if err := c.ShouldBindJSON(&data); err != nil {
 | 
				
			||||||
 | 
							logger.Error("非法任务回调:%+v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err := h.pool.Notify(data)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logger.Error(err)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							userId := h.GetLoginUserId(c)
 | 
				
			||||||
 | 
							client := h.pool.Clients.Get(userId)
 | 
				
			||||||
 | 
							if client != nil {
 | 
				
			||||||
 | 
								_ = client.Send([]byte("Task Updated"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,8 +7,8 @@ import (
 | 
				
			|||||||
	"chatplus/utils"
 | 
						"chatplus/utils"
 | 
				
			||||||
	"chatplus/utils/resp"
 | 
						"chatplus/utils/resp"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"github.com/imroc/req/v3"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -22,23 +22,176 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS
 | 
				
			|||||||
	return &TestHandler{db: db, snowflake: snowflake, js: js}
 | 
						return &TestHandler{db: db, snowflake: snowflake, js: js}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *TestHandler) Test(c *gin.Context) {
 | 
					type reqBody struct {
 | 
				
			||||||
	//h.initUserNickname(c)
 | 
						BotType       string        `json:"botType"`
 | 
				
			||||||
	//h.initMjTaskId(c)
 | 
						Prompt        string        `json:"prompt"`
 | 
				
			||||||
 | 
						Base64Array   []interface{} `json:"base64Array,omitempty"`
 | 
				
			||||||
 | 
						AccountFilter struct {
 | 
				
			||||||
 | 
							InstanceId          string        `json:"instanceId"`
 | 
				
			||||||
 | 
							Modes               []interface{} `json:"modes"`
 | 
				
			||||||
 | 
							Remix               bool          `json:"remix"`
 | 
				
			||||||
 | 
							RemixAutoConsidered bool          `json:"remixAutoConsidered"`
 | 
				
			||||||
 | 
						} `json:"accountFilter,omitempty"`
 | 
				
			||||||
 | 
						NotifyHook string `json:"notifyHook"`
 | 
				
			||||||
 | 
						State      string `json:"state,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	orderId, _ := h.snowflake.Next(false)
 | 
					type resBody struct {
 | 
				
			||||||
	params := payment.JPayReq{
 | 
						Code        int    `json:"code"`
 | 
				
			||||||
		TotalFee:   12345,
 | 
						Description string `json:"description"`
 | 
				
			||||||
		OutTradeNo: orderId,
 | 
						Properties  struct {
 | 
				
			||||||
		Subject:    "支付测试",
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
						Result string `json:"result"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *TestHandler) Test(c *gin.Context) {
 | 
				
			||||||
 | 
						query(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func upscale(c *gin.Context) {
 | 
				
			||||||
 | 
						apiURL := "https://api.openai1s.cn/mj/submit/action"
 | 
				
			||||||
 | 
						token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
 | 
				
			||||||
 | 
						body := map[string]string{
 | 
				
			||||||
 | 
							"customId":   "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
 | 
				
			||||||
 | 
							"taskId":     "1704880156226095",
 | 
				
			||||||
 | 
							"notifyHook": "http://r9it.com:6004/api/test/mj",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	r := h.js.Pay(params)
 | 
						var res resBody
 | 
				
			||||||
	if !r.IsOK() {
 | 
						var resErr errRes
 | 
				
			||||||
		resp.ERROR(c, r.ReturnMsg)
 | 
						r, err := req.C().R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", "Bearer "+token).
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							SetErrorResult(&resErr).
 | 
				
			||||||
 | 
							Post(apiURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, "请求出错:"+err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resp.SUCCESS(c, r)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.SUCCESS(c, res)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type queryRes struct {
 | 
				
			||||||
 | 
						Action  string `json:"action"`
 | 
				
			||||||
 | 
						Buttons []struct {
 | 
				
			||||||
 | 
							CustomId string `json:"customId"`
 | 
				
			||||||
 | 
							Emoji    string `json:"emoji"`
 | 
				
			||||||
 | 
							Label    string `json:"label"`
 | 
				
			||||||
 | 
							Style    int    `json:"style"`
 | 
				
			||||||
 | 
							Type     int    `json:"type"`
 | 
				
			||||||
 | 
						} `json:"buttons"`
 | 
				
			||||||
 | 
						Description string `json:"description"`
 | 
				
			||||||
 | 
						FailReason  string `json:"failReason"`
 | 
				
			||||||
 | 
						FinishTime  int    `json:"finishTime"`
 | 
				
			||||||
 | 
						Id          string `json:"id"`
 | 
				
			||||||
 | 
						ImageUrl    string `json:"imageUrl"`
 | 
				
			||||||
 | 
						Progress    string `json:"progress"`
 | 
				
			||||||
 | 
						Prompt      string `json:"prompt"`
 | 
				
			||||||
 | 
						PromptEn    string `json:"promptEn"`
 | 
				
			||||||
 | 
						Properties  struct {
 | 
				
			||||||
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
						StartTime  int    `json:"startTime"`
 | 
				
			||||||
 | 
						State      string `json:"state"`
 | 
				
			||||||
 | 
						Status     string `json:"status"`
 | 
				
			||||||
 | 
						SubmitTime int    `json:"submitTime"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func query(c *gin.Context) {
 | 
				
			||||||
 | 
						apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
 | 
				
			||||||
 | 
						token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
 | 
				
			||||||
 | 
						var res queryRes
 | 
				
			||||||
 | 
						r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							Get(apiURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, "请求出错:"+err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							resp.ERROR(c, "返回错误状态:"+r.Status)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.SUCCESS(c, res)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type errRes struct {
 | 
				
			||||||
 | 
						Error struct {
 | 
				
			||||||
 | 
							Message string `json:"message"`
 | 
				
			||||||
 | 
						} `json:"error"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func image(c *gin.Context) {
 | 
				
			||||||
 | 
						apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
 | 
				
			||||||
 | 
						token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
 | 
				
			||||||
 | 
						body := reqBody{
 | 
				
			||||||
 | 
							BotType:    "MID_JOURNEY",
 | 
				
			||||||
 | 
							Prompt:     "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
 | 
				
			||||||
 | 
							NotifyHook: "http://r9it.com:6004/api/test/mj",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var res resBody
 | 
				
			||||||
 | 
						var resErr errRes
 | 
				
			||||||
 | 
						r, err := req.C().R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", "Bearer "+token).
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							SetErrorResult(&resErr).
 | 
				
			||||||
 | 
							Post(apiURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							resp.ERROR(c, "请求出错:"+err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.SUCCESS(c, res)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type cbReq struct {
 | 
				
			||||||
 | 
						Id          string      `json:"id"`
 | 
				
			||||||
 | 
						Action      string      `json:"action"`
 | 
				
			||||||
 | 
						Status      string      `json:"status"`
 | 
				
			||||||
 | 
						Prompt      string      `json:"prompt"`
 | 
				
			||||||
 | 
						PromptEn    string      `json:"promptEn"`
 | 
				
			||||||
 | 
						Description string      `json:"description"`
 | 
				
			||||||
 | 
						SubmitTime  int64       `json:"submitTime"`
 | 
				
			||||||
 | 
						StartTime   int64       `json:"startTime"`
 | 
				
			||||||
 | 
						FinishTime  int64       `json:"finishTime"`
 | 
				
			||||||
 | 
						Progress    string      `json:"progress"`
 | 
				
			||||||
 | 
						ImageUrl    string      `json:"imageUrl"`
 | 
				
			||||||
 | 
						FailReason  interface{} `json:"failReason"`
 | 
				
			||||||
 | 
						Properties  struct {
 | 
				
			||||||
 | 
							FinalPrompt string `json:"finalPrompt"`
 | 
				
			||||||
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *TestHandler) Mj(c *gin.Context) {
 | 
				
			||||||
 | 
						var data cbReq
 | 
				
			||||||
 | 
						if err := c.ShouldBindJSON(&data); err != nil {
 | 
				
			||||||
 | 
							logger.Error(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						logger.Debugf("任务ID:%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
 | 
				
			||||||
 | 
						apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
 | 
				
			||||||
 | 
						token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
 | 
				
			||||||
 | 
						var res queryRes
 | 
				
			||||||
 | 
						_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							Get(apiURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *TestHandler) initUserNickname(c *gin.Context) {
 | 
					func (h *TestHandler) initUserNickname(c *gin.Context) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -235,6 +235,7 @@ func main() {
 | 
				
			|||||||
			group.POST("variation", h.Variation)
 | 
								group.POST("variation", h.Variation)
 | 
				
			||||||
			group.GET("jobs", h.JobList)
 | 
								group.GET("jobs", h.JobList)
 | 
				
			||||||
			group.POST("remove", h.Remove)
 | 
								group.POST("remove", h.Remove)
 | 
				
			||||||
 | 
								group.POST("notify", h.Notify)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/sd")
 | 
								group := s.Engine.Group("/api/sd")
 | 
				
			||||||
@@ -367,6 +368,7 @@ func main() {
 | 
				
			|||||||
		fx.Provide(handler.NewTestHandler),
 | 
							fx.Provide(handler.NewTestHandler),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 | 
				
			||||||
			s.Engine.GET("/api/test", h.Test)
 | 
								s.Engine.GET("/api/test", h.Test)
 | 
				
			||||||
 | 
								s.Engine.POST("/api/test/mj", h.Mj)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
							fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
				
			||||||
			err := s.Run(db)
 | 
								err := s.Run(db)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -33,7 +33,7 @@ func NewBot(name string, proxy string, config types.MidJourneyConfig, service *S
 | 
				
			|||||||
	// use CDN reverse proxy
 | 
						// use CDN reverse proxy
 | 
				
			||||||
	if config.UseCDN {
 | 
						if config.UseCDN {
 | 
				
			||||||
		discordgo.SetEndpointDiscord(config.DiscordAPI)
 | 
							discordgo.SetEndpointDiscord(config.DiscordAPI)
 | 
				
			||||||
		discordgo.SetEndpointCDN(config.DiscordCDN)
 | 
							discordgo.SetEndpointCDN("https://cdn.discordapp.com")
 | 
				
			||||||
		discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
 | 
							discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
 | 
				
			||||||
		bot.MjGateway = config.DiscordGateway + "/"
 | 
							bot.MjGateway = config.DiscordGateway + "/"
 | 
				
			||||||
	} else { // use proxy
 | 
						} else { // use proxy
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,12 +11,13 @@ import (
 | 
				
			|||||||
// MidJourney client
 | 
					// MidJourney client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Client struct {
 | 
					type Client struct {
 | 
				
			||||||
	client *req.Client
 | 
						client    *req.Client
 | 
				
			||||||
	Config types.MidJourneyConfig
 | 
						Config    types.MidJourneyConfig
 | 
				
			||||||
	apiURL string
 | 
						imgCdnURL string
 | 
				
			||||||
 | 
						apiURL    string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
					func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client {
 | 
				
			||||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
						client := req.C().SetTimeout(10 * time.Second)
 | 
				
			||||||
	var apiURL string
 | 
						var apiURL string
 | 
				
			||||||
	// set proxy URL
 | 
						// set proxy URL
 | 
				
			||||||
@@ -29,7 +30,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &Client{client: client, Config: config, apiURL: apiURL}
 | 
						return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *Client) Imagine(prompt string) error {
 | 
					func (c *Client) Imagine(prompt string) error {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										171
									
								
								api/service/mj/plus/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								api/service/mj/plus/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,171 @@
 | 
				
			|||||||
 | 
					package plus
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						logger2 "chatplus/logger"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/imroc/req/v3"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var logger = logger2.GetLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Client MidJourney Plus Client
 | 
				
			||||||
 | 
					type Client struct {
 | 
				
			||||||
 | 
						Config types.MidJourneyPlusConfig
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewClient(config types.MidJourneyPlusConfig) *Client {
 | 
				
			||||||
 | 
						return &Client{Config: config}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageReq struct {
 | 
				
			||||||
 | 
						BotType       string        `json:"botType"`
 | 
				
			||||||
 | 
						Prompt        string        `json:"prompt"`
 | 
				
			||||||
 | 
						Base64Array   []interface{} `json:"base64Array,omitempty"`
 | 
				
			||||||
 | 
						AccountFilter struct {
 | 
				
			||||||
 | 
							InstanceId          string        `json:"instanceId"`
 | 
				
			||||||
 | 
							Modes               []interface{} `json:"modes"`
 | 
				
			||||||
 | 
							Remix               bool          `json:"remix"`
 | 
				
			||||||
 | 
							RemixAutoConsidered bool          `json:"remixAutoConsidered"`
 | 
				
			||||||
 | 
						} `json:"accountFilter,omitempty"`
 | 
				
			||||||
 | 
						NotifyHook string `json:"notifyHook"`
 | 
				
			||||||
 | 
						State      string `json:"state,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ImageRes struct {
 | 
				
			||||||
 | 
						Code        int    `json:"code"`
 | 
				
			||||||
 | 
						Description string `json:"description"`
 | 
				
			||||||
 | 
						Properties  struct {
 | 
				
			||||||
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
						Result string `json:"result"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ErrRes struct {
 | 
				
			||||||
 | 
						Error struct {
 | 
				
			||||||
 | 
							Message string `json:"message"`
 | 
				
			||||||
 | 
						} `json:"error"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Client) Imagine(prompt string) (ImageRes, error) {
 | 
				
			||||||
 | 
						apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
 | 
				
			||||||
 | 
						body := ImageReq{
 | 
				
			||||||
 | 
							BotType:    "MID_JOURNEY",
 | 
				
			||||||
 | 
							Prompt:     prompt,
 | 
				
			||||||
 | 
							NotifyHook: c.Config.NotifyURL,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var res ImageRes
 | 
				
			||||||
 | 
						var errRes ErrRes
 | 
				
			||||||
 | 
						r, err := req.C().R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							SetErrorResult(&errRes).
 | 
				
			||||||
 | 
							Post(apiURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return res, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Upscale 放大指定的图片
 | 
				
			||||||
 | 
					func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) {
 | 
				
			||||||
 | 
						body := map[string]string{
 | 
				
			||||||
 | 
							"customId":   fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
 | 
				
			||||||
 | 
							"taskId":     messageId,
 | 
				
			||||||
 | 
							"notifyHook": c.Config.NotifyURL,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
 | 
				
			||||||
 | 
						var res ImageRes
 | 
				
			||||||
 | 
						var errRes ErrRes
 | 
				
			||||||
 | 
						r, err := req.C().R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							SetErrorResult(&errRes).
 | 
				
			||||||
 | 
							Post(apiURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return res, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
				
			||||||
 | 
					func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) {
 | 
				
			||||||
 | 
						body := map[string]string{
 | 
				
			||||||
 | 
							"customId":   fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
 | 
				
			||||||
 | 
							"taskId":     messageId,
 | 
				
			||||||
 | 
							"notifyHook": c.Config.NotifyURL,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
 | 
				
			||||||
 | 
						var res ImageRes
 | 
				
			||||||
 | 
						var errRes ErrRes
 | 
				
			||||||
 | 
						r, err := req.C().R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							SetErrorResult(&errRes).
 | 
				
			||||||
 | 
							Post(apiURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return res, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type QueryRes struct {
 | 
				
			||||||
 | 
						Action  string `json:"action"`
 | 
				
			||||||
 | 
						Buttons []struct {
 | 
				
			||||||
 | 
							CustomId string `json:"customId"`
 | 
				
			||||||
 | 
							Emoji    string `json:"emoji"`
 | 
				
			||||||
 | 
							Label    string `json:"label"`
 | 
				
			||||||
 | 
							Style    int    `json:"style"`
 | 
				
			||||||
 | 
							Type     int    `json:"type"`
 | 
				
			||||||
 | 
						} `json:"buttons"`
 | 
				
			||||||
 | 
						Description string `json:"description"`
 | 
				
			||||||
 | 
						FailReason  string `json:"failReason"`
 | 
				
			||||||
 | 
						FinishTime  int    `json:"finishTime"`
 | 
				
			||||||
 | 
						Id          string `json:"id"`
 | 
				
			||||||
 | 
						ImageUrl    string `json:"imageUrl"`
 | 
				
			||||||
 | 
						Progress    string `json:"progress"`
 | 
				
			||||||
 | 
						Prompt      string `json:"prompt"`
 | 
				
			||||||
 | 
						PromptEn    string `json:"promptEn"`
 | 
				
			||||||
 | 
						Properties  struct {
 | 
				
			||||||
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
						StartTime  int    `json:"startTime"`
 | 
				
			||||||
 | 
						State      string `json:"state"`
 | 
				
			||||||
 | 
						Status     string `json:"status"`
 | 
				
			||||||
 | 
						SubmitTime int    `json:"submitTime"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *Client) QueryTask(taskId string) (QueryRes, error) {
 | 
				
			||||||
 | 
						apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId)
 | 
				
			||||||
 | 
						var res QueryRes
 | 
				
			||||||
 | 
						r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).
 | 
				
			||||||
 | 
							Get(apiURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return QueryRes{}, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if r.IsErrorState() {
 | 
				
			||||||
 | 
							return QueryRes{}, errors.New("error status:" + r.Status)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return res, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										164
									
								
								api/service/mj/plus/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								api/service/mj/plus/service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,164 @@
 | 
				
			|||||||
 | 
					package plus
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						"chatplus/store"
 | 
				
			||||||
 | 
						"chatplus/store/model"
 | 
				
			||||||
 | 
						"chatplus/utils"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Service MJ 绘画服务
 | 
				
			||||||
 | 
					type Service struct {
 | 
				
			||||||
 | 
						name             string  // service name
 | 
				
			||||||
 | 
						Client           *Client // MJ Client
 | 
				
			||||||
 | 
						taskQueue        *store.RedisQueue
 | 
				
			||||||
 | 
						notifyQueue      *store.RedisQueue
 | 
				
			||||||
 | 
						db               *gorm.DB
 | 
				
			||||||
 | 
						maxHandleTaskNum int32             // max task number current service can handle
 | 
				
			||||||
 | 
						handledTaskNum   int32             // already handled task number
 | 
				
			||||||
 | 
						taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
				
			||||||
 | 
						taskTimeout      int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
 | 
				
			||||||
 | 
						return &Service{
 | 
				
			||||||
 | 
							name:             name,
 | 
				
			||||||
 | 
							db:               db,
 | 
				
			||||||
 | 
							taskQueue:        taskQueue,
 | 
				
			||||||
 | 
							notifyQueue:      notifyQueue,
 | 
				
			||||||
 | 
							Client:           client,
 | 
				
			||||||
 | 
							taskTimeout:      timeout,
 | 
				
			||||||
 | 
							maxHandleTaskNum: maxTaskNum,
 | 
				
			||||||
 | 
							taskStartTimes:   make(map[int]time.Time, 0),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Service) Run() {
 | 
				
			||||||
 | 
						logger.Infof("Starting MidJourney job consumer for %s", s.name)
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							s.checkTasks()
 | 
				
			||||||
 | 
							if !s.canHandleTask() {
 | 
				
			||||||
 | 
								// current service is full, can not handle more task
 | 
				
			||||||
 | 
								// waiting for running task finish
 | 
				
			||||||
 | 
								time.Sleep(time.Second * 3)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var task types.MjTask
 | 
				
			||||||
 | 
							err := s.taskQueue.LPop(&task)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logger.Errorf("taking task with error: %v", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// if it's reference message, check if it's this channel's  message
 | 
				
			||||||
 | 
							if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name {
 | 
				
			||||||
 | 
								s.taskQueue.RPush(task)
 | 
				
			||||||
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
 | 
				
			||||||
 | 
							var res ImageRes
 | 
				
			||||||
 | 
							switch task.Type {
 | 
				
			||||||
 | 
							case types.TaskImage:
 | 
				
			||||||
 | 
								index := strings.Index(task.Prompt, " ")
 | 
				
			||||||
 | 
								res, err = s.Client.Imagine(task.Prompt[index+1:])
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							case types.TaskUpscale:
 | 
				
			||||||
 | 
								res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							case types.TaskVariation:
 | 
				
			||||||
 | 
								res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
				
			||||||
 | 
								logger.Error("绘画任务执行失败:", err)
 | 
				
			||||||
 | 
								// update the task progress
 | 
				
			||||||
 | 
								s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
				
			||||||
 | 
								// 任务失败,通知前端
 | 
				
			||||||
 | 
								s.notifyQueue.RPush(task.UserId)
 | 
				
			||||||
 | 
								// restore img_call quota
 | 
				
			||||||
 | 
								s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// TODO: 任务提交失败,加入队列重试
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							logger.Infof("任务提交成功:%+v", res)
 | 
				
			||||||
 | 
							// lock the task until the execute timeout
 | 
				
			||||||
 | 
							s.taskStartTimes[task.Id] = time.Now()
 | 
				
			||||||
 | 
							atomic.AddInt32(&s.handledTaskNum, 1)
 | 
				
			||||||
 | 
							// 更新任务 ID/频道
 | 
				
			||||||
 | 
							s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
 | 
				
			||||||
 | 
								"task_id":    res.Result,
 | 
				
			||||||
 | 
								"channel_id": s.Client.Config.Name,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check if current service instance can handle more task
 | 
				
			||||||
 | 
					func (s *Service) canHandleTask() bool {
 | 
				
			||||||
 | 
						handledNum := atomic.LoadInt32(&s.handledTaskNum)
 | 
				
			||||||
 | 
						return handledNum < s.maxHandleTaskNum
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// remove the expired tasks
 | 
				
			||||||
 | 
					func (s *Service) checkTasks() {
 | 
				
			||||||
 | 
						for k, t := range s.taskStartTimes {
 | 
				
			||||||
 | 
							if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
				
			||||||
 | 
								delete(s.taskStartTimes, k)
 | 
				
			||||||
 | 
								atomic.AddInt32(&s.handledTaskNum, -1)
 | 
				
			||||||
 | 
								// delete task from database
 | 
				
			||||||
 | 
								s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CBReq struct {
 | 
				
			||||||
 | 
						Id          string      `json:"id"`
 | 
				
			||||||
 | 
						Action      string      `json:"action"`
 | 
				
			||||||
 | 
						Status      string      `json:"status"`
 | 
				
			||||||
 | 
						Prompt      string      `json:"prompt"`
 | 
				
			||||||
 | 
						PromptEn    string      `json:"promptEn"`
 | 
				
			||||||
 | 
						Description string      `json:"description"`
 | 
				
			||||||
 | 
						SubmitTime  int64       `json:"submitTime"`
 | 
				
			||||||
 | 
						StartTime   int64       `json:"startTime"`
 | 
				
			||||||
 | 
						FinishTime  int64       `json:"finishTime"`
 | 
				
			||||||
 | 
						Progress    string      `json:"progress"`
 | 
				
			||||||
 | 
						ImageUrl    string      `json:"imageUrl"`
 | 
				
			||||||
 | 
						FailReason  interface{} `json:"failReason"`
 | 
				
			||||||
 | 
						Properties  struct {
 | 
				
			||||||
 | 
							FinalPrompt string `json:"finalPrompt"`
 | 
				
			||||||
 | 
						} `json:"properties"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
 | 
				
			||||||
 | 
						job.Prompt = data.Properties.FinalPrompt
 | 
				
			||||||
 | 
						if data.ImageUrl != "" {
 | 
				
			||||||
 | 
							job.OrgURL = data.ImageUrl
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						job.UseProxy = true
 | 
				
			||||||
 | 
						job.MessageId = data.Id
 | 
				
			||||||
 | 
						logger.Debugf("JOB: %+v", job)
 | 
				
			||||||
 | 
						res := s.db.Updates(&job)
 | 
				
			||||||
 | 
						if res.Error != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error with update job: %v", res.Error)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if data.Status == "SUCCESS" {
 | 
				
			||||||
 | 
							// release lock task
 | 
				
			||||||
 | 
							atomic.AddInt32(&s.handledTaskNum, -1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.notifyQueue.RPush(job.UserId)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -2,11 +2,13 @@ package mj
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"chatplus/core/types"
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						"chatplus/service/mj/plus"
 | 
				
			||||||
	"chatplus/service/oss"
 | 
						"chatplus/service/oss"
 | 
				
			||||||
	"chatplus/store"
 | 
						"chatplus/store"
 | 
				
			||||||
	"chatplus/store/model"
 | 
						"chatplus/store/model"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
@@ -14,7 +16,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ServicePool Mj service pool
 | 
					// ServicePool Mj service pool
 | 
				
			||||||
type ServicePool struct {
 | 
					type ServicePool struct {
 | 
				
			||||||
	services        []*Service
 | 
						services        []interface{}
 | 
				
			||||||
	taskQueue       *store.RedisQueue
 | 
						taskQueue       *store.RedisQueue
 | 
				
			||||||
	notifyQueue     *store.RedisQueue
 | 
						notifyQueue     *store.RedisQueue
 | 
				
			||||||
	db              *gorm.DB
 | 
						db              *gorm.DB
 | 
				
			||||||
@@ -23,37 +25,53 @@ type ServicePool struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
					func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
				
			||||||
	services := make([]*Service, 0)
 | 
						services := make([]interface{}, 0)
 | 
				
			||||||
	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
						taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
				
			||||||
	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
						notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
				
			||||||
	// create mj client and service
 | 
					
 | 
				
			||||||
	for k, config := range appConfig.MjConfigs {
 | 
						for k, config := range appConfig.MjPlusConfigs {
 | 
				
			||||||
		if config.Enabled == false {
 | 
							if config.Enabled == false {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// create mj client
 | 
							client := plus.NewClient(config)
 | 
				
			||||||
		client := NewClient(config, appConfig.ProxyURL)
 | 
							name := fmt.Sprintf("MidJourney Plus Service-%d", k)
 | 
				
			||||||
 | 
							servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
 | 
				
			||||||
		name := fmt.Sprintf("MjService-%d", k)
 | 
					 | 
				
			||||||
		// create mj service
 | 
					 | 
				
			||||||
		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
 | 
					 | 
				
			||||||
		botName := fmt.Sprintf("MjBot-%d", k)
 | 
					 | 
				
			||||||
		bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		err = bot.Run()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		// run mj service
 | 
					 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
			service.Run()
 | 
								servicePlus.Run()
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
 | 
							services = append(services, servicePlus)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		services = append(services, service)
 | 
						if len(services) == 0 {
 | 
				
			||||||
 | 
							// create mj client and service
 | 
				
			||||||
 | 
							for k, config := range appConfig.MjConfigs {
 | 
				
			||||||
 | 
								if config.Enabled == false {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								// create mj client
 | 
				
			||||||
 | 
								client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								name := fmt.Sprintf("MjService-%d", k)
 | 
				
			||||||
 | 
								// create mj service
 | 
				
			||||||
 | 
								service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
 | 
				
			||||||
 | 
								botName := fmt.Sprintf("MjBot-%d", k)
 | 
				
			||||||
 | 
								bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								err = bot.Run()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// run mj service
 | 
				
			||||||
 | 
								go func() {
 | 
				
			||||||
 | 
									service.Run()
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								services = append(services, service)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &ServicePool{
 | 
						return &ServicePool{
 | 
				
			||||||
@@ -94,7 +112,24 @@ func (p *ServicePool) DownloadImages() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			// download images
 | 
								// download images
 | 
				
			||||||
			for _, v := range items {
 | 
								for _, v := range items {
 | 
				
			||||||
				imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
 | 
									if v.OrgURL == "" {
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									var imgURL string
 | 
				
			||||||
 | 
									var err error
 | 
				
			||||||
 | 
									if v.UseProxy {
 | 
				
			||||||
 | 
										if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
 | 
				
			||||||
 | 
											task, _ := servicePlus.Client.QueryTask(v.TaskId)
 | 
				
			||||||
 | 
											if task.ImageUrl != "" {
 | 
				
			||||||
 | 
												imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											if len(task.Buttons) > 0 {
 | 
				
			||||||
 | 
												v.Hash = getImageHash(task.Buttons[0].CustomId)
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					logger.Error("error with download image: ", err)
 | 
										logger.Error("error with download image: ", err)
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
@@ -125,3 +160,37 @@ func (p *ServicePool) PushTask(task types.MjTask) {
 | 
				
			|||||||
func (p *ServicePool) HasAvailableService() bool {
 | 
					func (p *ServicePool) HasAvailableService() bool {
 | 
				
			||||||
	return len(p.services) > 0
 | 
						return len(p.services) > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ServicePool) Notify(data plus.CBReq) error {
 | 
				
			||||||
 | 
						logger.Infof("收到任务回调:%+v", data)
 | 
				
			||||||
 | 
						var job model.MidJourneyJob
 | 
				
			||||||
 | 
						res := p.db.Where("task_id = ?", data.Id).First(&job)
 | 
				
			||||||
 | 
						if res.Error != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("非法任务:%s", data.Id)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
 | 
				
			||||||
 | 
							return servicePlus.Notify(data, job)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ServicePool) getServicePlus(name string) *plus.Service {
 | 
				
			||||||
 | 
						for _, s := range p.services {
 | 
				
			||||||
 | 
							if servicePlus, ok := s.(*plus.Service); ok {
 | 
				
			||||||
 | 
								if servicePlus.Client.Config.Name == name {
 | 
				
			||||||
 | 
									return servicePlus
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getImageHash(action string) string {
 | 
				
			||||||
 | 
						split := strings.Split(action, "::")
 | 
				
			||||||
 | 
						if len(split) > 5 {
 | 
				
			||||||
 | 
							return split[4]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return split[len(split)-1]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,8 +58,6 @@ func (s *Service) Run() {
 | 
				
			|||||||
		// if it's reference message, check if it's this channel's  message
 | 
							// if it's reference message, check if it's this channel's  message
 | 
				
			||||||
		if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
 | 
							if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
 | 
				
			||||||
			s.taskQueue.RPush(task)
 | 
								s.taskQueue.RPush(task)
 | 
				
			||||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
					 | 
				
			||||||
			s.notifyQueue.RPush(task.UserId)
 | 
					 | 
				
			||||||
			time.Sleep(time.Second)
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -143,7 +141,7 @@ func (s *Service) Notify(data CBReq) {
 | 
				
			|||||||
	job.OrgURL = data.Image.URL
 | 
						job.OrgURL = data.Image.URL
 | 
				
			||||||
	if s.client.Config.UseCDN {
 | 
						if s.client.Config.UseCDN {
 | 
				
			||||||
		job.UseProxy = true
 | 
							job.UseProxy = true
 | 
				
			||||||
		job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
 | 
							job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	res = s.db.Updates(&job)
 | 
						res = s.db.Updates(&job)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,12 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func main() {
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						str := "7151109597841850368 一个漂亮的中国女孩,手上拿着一桶爆米花,脸上带着迷人的微笑,电影效果"
 | 
				
			||||||
 | 
						index := strings.Index(str, " ")
 | 
				
			||||||
 | 
						fmt.Println(str[index+1:])
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -700,6 +700,7 @@ const generate = () => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 图片放大任务
 | 
					// 图片放大任务
 | 
				
			||||||
const upscale = (index, item) => {
 | 
					const upscale = (index, item) => {
 | 
				
			||||||
 | 
					  console.log(item)
 | 
				
			||||||
  send('/api/mj/upscale', index, item)
 | 
					  send('/api/mj/upscale', index, item)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user