diff --git a/api/core/app_server.go b/api/core/app_server.go index 90f48c06..d7bf5590 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -152,6 +152,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/mj/jobs" || c.Request.URL.Path == "/api/mj/client" || + c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/jobs" || c.Request.URL.Path == "/api/upload" || diff --git a/api/core/types/config.go b/api/core/types/config.go index 51e773b2..c8ade652 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -5,21 +5,23 @@ import ( ) type AppConfig struct { - Path string `toml:"-"` - Listen string - Session Session - ProxyURL string - MysqlDns string // mysql 连接地址 - Manager Manager // 后台管理员账户信息 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs - SmsConfig AliYunSmsConfig // AliYun send message service config - OSS OSSConfig // OSS config - MjConfigs []MidJourneyConfig // mj AI draw service pool - WeChatBot bool // 是否启用微信机器人 - SdConfigs []StableDiffusionConfig // sd AI draw service pool + Path string `toml:"-"` + Listen string + Session Session + ProxyURL string + MysqlDns string // mysql 连接地址 + Manager Manager // 后台管理员账户信息 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs + SmsConfig AliYunSmsConfig // AliYun send message service config + OSS OSSConfig // OSS config + MjConfigs []MidJourneyConfig // mj AI draw service pool + MjPlusConfigs []MidJourneyPlusConfig // MJ plus config + ImgCdnURL string // 图片反代加速地址 + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool XXLConfig XXLConfig AlipayConfig AlipayConfig @@ -60,7 +62,6 @@ type MidJourneyConfig struct { ChanelId string // Chanel ID UseCDN bool DiscordAPI string - DiscordCDN string DiscordGateway string } @@ -71,6 +72,14 @@ type StableDiffusionConfig struct { Txt2ImgJsonPath string } +type MidJourneyPlusConfig struct { + Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 + Name string // 服务名称,保持唯一 + ApiURL string + ApiKey string + NotifyURL string // 任务进度更新回调地址 +} + type AliYunSmsConfig struct { AccessKey string AccessSecret string diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 94bb7ea4..b49a08c9 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -5,6 +5,7 @@ import ( "chatplus/core/types" "chatplus/service" "chatplus/service/mj" + "chatplus/service/mj/plus" "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" @@ -203,7 +204,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { } idValue, _ := c.Get(types.LoginUserID) - jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) taskId, _ := h.snowflake.Next(true) job := model.MidJourneyJob{ @@ -221,7 +221,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { } h.pool.PushTask(types.MjTask{ - Id: jobId, + Id: int(job.Id), SessionId: data.SessionId, Type: types.TaskUpscale, Prompt: data.Prompt, @@ -251,7 +251,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { } idValue, _ := c.Get(types.LoginUserID) - jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) taskId, _ := h.snowflake.Next(true) job := model.MidJourneyJob{ @@ -270,7 +269,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { } h.pool.PushTask(types.MjTask{ - Id: jobId, + Id: int(job.Id), SessionId: data.SessionId, Type: types.TaskVariation, Prompt: data.Prompt, @@ -340,9 +339,13 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { // 正在运行中任务使用代理访问图片 if item.ImgURL == "" && item.OrgURL != "" { - image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) - if err == nil { - job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + if h.App.Config.ImgCdnURL != "" { + job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL) + } else { + image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } } } } @@ -382,3 +385,24 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { resp.SUCCESS(c) } + +// Notify MidJourney Plus 服务任务回调处理 +func (h *MidJourneyHandler) Notify(c *gin.Context) { + var data plus.CBReq + if err := c.ShouldBindJSON(&data); err != nil { + logger.Error("非法任务回调:%+v", err) + return + } + err := h.pool.Notify(data) + if err != nil { + logger.Error(err) + } else { + userId := h.GetLoginUserId(c) + client := h.pool.Clients.Get(userId) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + } + + resp.SUCCESS(c) +} diff --git a/api/handler/test_handler.go b/api/handler/test_handler.go index 3cc1e749..1c13f11f 100644 --- a/api/handler/test_handler.go +++ b/api/handler/test_handler.go @@ -7,8 +7,8 @@ import ( "chatplus/utils" "chatplus/utils/resp" "fmt" - "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" "gorm.io/gorm" ) @@ -22,23 +22,176 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS return &TestHandler{db: db, snowflake: snowflake, js: js} } -func (h *TestHandler) Test(c *gin.Context) { - //h.initUserNickname(c) - //h.initMjTaskId(c) +type reqBody struct { + BotType string `json:"botType"` + Prompt string `json:"prompt"` + Base64Array []interface{} `json:"base64Array,omitempty"` + AccountFilter struct { + InstanceId string `json:"instanceId"` + Modes []interface{} `json:"modes"` + Remix bool `json:"remix"` + RemixAutoConsidered bool `json:"remixAutoConsidered"` + } `json:"accountFilter,omitempty"` + NotifyHook string `json:"notifyHook"` + State string `json:"state,omitempty"` +} - orderId, _ := h.snowflake.Next(false) - params := payment.JPayReq{ - TotalFee: 12345, - OutTradeNo: orderId, - Subject: "支付测试", +type resBody struct { + Code int `json:"code"` + Description string `json:"description"` + Properties struct { + } `json:"properties"` + Result string `json:"result"` +} + +func (h *TestHandler) Test(c *gin.Context) { + query(c) + +} + +func upscale(c *gin.Context) { + apiURL := "https://api.openai1s.cn/mj/submit/action" + token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a" + body := map[string]string{ + "customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a", + "taskId": "1704880156226095", + "notifyHook": "http://r9it.com:6004/api/test/mj", } - r := h.js.Pay(params) - if !r.IsOK() { - resp.ERROR(c, r.ReturnMsg) + var res resBody + var resErr errRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+token). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&resErr). + Post(apiURL) + if err != nil { + resp.ERROR(c, "请求出错:"+err.Error()) return } - resp.SUCCESS(c, r) + if r.IsErrorState() { + resp.ERROR(c, "返回错误状态:"+resErr.Error.Message) + return + } + + resp.SUCCESS(c, res) + +} + +type queryRes struct { + Action string `json:"action"` + Buttons []struct { + CustomId string `json:"customId"` + Emoji string `json:"emoji"` + Label string `json:"label"` + Style int `json:"style"` + Type int `json:"type"` + } `json:"buttons"` + Description string `json:"description"` + FailReason string `json:"failReason"` + FinishTime int `json:"finishTime"` + Id string `json:"id"` + ImageUrl string `json:"imageUrl"` + Progress string `json:"progress"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Properties struct { + } `json:"properties"` + StartTime int `json:"startTime"` + State string `json:"state"` + Status string `json:"status"` + SubmitTime int `json:"submitTime"` +} + +func query(c *gin.Context) { + apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch" + token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a" + var res queryRes + r, err := req.C().R().SetHeader("Authorization", "Bearer "+token). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + resp.ERROR(c, "请求出错:"+err.Error()) + return + } + + if r.IsErrorState() { + resp.ERROR(c, "返回错误状态:"+r.Status) + return + } + + resp.SUCCESS(c, res) +} + +type errRes struct { + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +func image(c *gin.Context) { + apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine" + token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a" + body := reqBody{ + BotType: "MID_JOURNEY", + Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6", + NotifyHook: "http://r9it.com:6004/api/test/mj", + } + var res resBody + var resErr errRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+token). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&resErr). + Post(apiURL) + if err != nil { + resp.ERROR(c, "请求出错:"+err.Error()) + return + } + + if r.IsErrorState() { + resp.ERROR(c, "返回错误状态:"+resErr.Error.Message) + return + } + + resp.SUCCESS(c, res) +} + +type cbReq struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Progress string `json:"progress"` + ImageUrl string `json:"imageUrl"` + FailReason interface{} `json:"failReason"` + Properties struct { + FinalPrompt string `json:"finalPrompt"` + } `json:"properties"` +} + +func (h *TestHandler) Mj(c *gin.Context) { + var data cbReq + if err := c.ShouldBindJSON(&data); err != nil { + logger.Error(err) + } + logger.Debugf("任务ID:%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt) + apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch" + token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a" + var res queryRes + _, _ = req.C().R().SetHeader("Authorization", "Bearer "+token). + SetSuccessResult(&res). + Get(apiURL) + + fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress) } func (h *TestHandler) initUserNickname(c *gin.Context) { diff --git a/api/main.go b/api/main.go index 879c93fb..a3cc6b83 100644 --- a/api/main.go +++ b/api/main.go @@ -235,6 +235,7 @@ func main() { group.POST("variation", h.Variation) group.GET("jobs", h.JobList) group.POST("remove", h.Remove) + group.POST("notify", h.Notify) }), fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { group := s.Engine.Group("/api/sd") @@ -367,6 +368,7 @@ func main() { fx.Provide(handler.NewTestHandler), fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { s.Engine.GET("/api/test", h.Test) + s.Engine.POST("/api/test/mj", h.Mj) }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { err := s.Run(db) diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go index beb70639..14ee8368 100644 --- a/api/service/mj/bot.go +++ b/api/service/mj/bot.go @@ -33,7 +33,7 @@ func NewBot(name string, proxy string, config types.MidJourneyConfig, service *S // use CDN reverse proxy if config.UseCDN { discordgo.SetEndpointDiscord(config.DiscordAPI) - discordgo.SetEndpointCDN(config.DiscordCDN) + discordgo.SetEndpointCDN("https://cdn.discordapp.com") discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/") bot.MjGateway = config.DiscordGateway + "/" } else { // use proxy diff --git a/api/service/mj/client.go b/api/service/mj/client.go index eb84240f..1540f285 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -11,12 +11,13 @@ import ( // MidJourney client type Client struct { - client *req.Client - Config types.MidJourneyConfig - apiURL string + client *req.Client + Config types.MidJourneyConfig + imgCdnURL string + apiURL string } -func NewClient(config types.MidJourneyConfig, proxy string) *Client { +func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client { client := req.C().SetTimeout(10 * time.Second) var apiURL string // set proxy URL @@ -29,7 +30,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client { } } - return &Client{client: client, Config: config, apiURL: apiURL} + return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL} } func (c *Client) Imagine(prompt string) error { diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go new file mode 100644 index 00000000..45ba3658 --- /dev/null +++ b/api/service/mj/plus/client.go @@ -0,0 +1,171 @@ +package plus + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "errors" + "fmt" + "github.com/imroc/req/v3" +) + +var logger = logger2.GetLogger() + +// Client MidJourney Plus Client +type Client struct { + Config types.MidJourneyPlusConfig +} + +func NewClient(config types.MidJourneyPlusConfig) *Client { + return &Client{Config: config} +} + +type ImageReq struct { + BotType string `json:"botType"` + Prompt string `json:"prompt"` + Base64Array []interface{} `json:"base64Array,omitempty"` + AccountFilter struct { + InstanceId string `json:"instanceId"` + Modes []interface{} `json:"modes"` + Remix bool `json:"remix"` + RemixAutoConsidered bool `json:"remixAutoConsidered"` + } `json:"accountFilter,omitempty"` + NotifyHook string `json:"notifyHook"` + State string `json:"state,omitempty"` +} + +type ImageRes struct { + Code int `json:"code"` + Description string `json:"description"` + Properties struct { + } `json:"properties"` + Result string `json:"result"` +} + +type ErrRes struct { + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +func (c *Client) Imagine(prompt string) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL) + body := ImageReq{ + BotType: "MID_JOURNEY", + Prompt: prompt, + NotifyHook: c.Config.NotifyURL, + } + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Upscale 放大指定的图片 +func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) { + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash), + "taskId": messageId, + "notifyHook": c.Config.NotifyURL, + } + apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) { + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash), + "taskId": messageId, + "notifyHook": c.Config.NotifyURL, + } + apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +type QueryRes struct { + Action string `json:"action"` + Buttons []struct { + CustomId string `json:"customId"` + Emoji string `json:"emoji"` + Label string `json:"label"` + Style int `json:"style"` + Type int `json:"type"` + } `json:"buttons"` + Description string `json:"description"` + FailReason string `json:"failReason"` + FinishTime int `json:"finishTime"` + Id string `json:"id"` + ImageUrl string `json:"imageUrl"` + Progress string `json:"progress"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Properties struct { + } `json:"properties"` + StartTime int `json:"startTime"` + State string `json:"state"` + Status string `json:"status"` + SubmitTime int `json:"submitTime"` +} + +func (c *Client) QueryTask(taskId string) (QueryRes, error) { + apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId) + var res QueryRes + r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + return QueryRes{}, err + } + + if r.IsErrorState() { + return QueryRes{}, errors.New("error status:" + r.Status) + } + + return res, nil +} diff --git a/api/service/mj/plus/service.go b/api/service/mj/plus/service.go new file mode 100644 index 00000000..788d4355 --- /dev/null +++ b/api/service/mj/plus/service.go @@ -0,0 +1,164 @@ +package plus + +import ( + "chatplus/core/types" + "chatplus/store" + "chatplus/store/model" + "chatplus/utils" + "fmt" + "strings" + "sync/atomic" + "time" + + "gorm.io/gorm" +) + +// Service MJ 绘画服务 +type Service struct { + name string // service name + Client *Client // MJ Client + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + maxHandleTaskNum int32 // max task number current service can handle + handledTaskNum int32 // already handled task number + taskStartTimes map[int]time.Time // task start time, to check if the task is timeout + taskTimeout int64 +} + +func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { + return &Service{ + name: name, + db: db, + taskQueue: taskQueue, + notifyQueue: notifyQueue, + Client: client, + taskTimeout: timeout, + maxHandleTaskNum: maxTaskNum, + taskStartTimes: make(map[int]time.Time, 0), + } +} + +func (s *Service) Run() { + logger.Infof("Starting MidJourney job consumer for %s", s.name) + for { + s.checkTasks() + if !s.canHandleTask() { + // current service is full, can not handle more task + // waiting for running task finish + time.Sleep(time.Second * 3) + continue + } + + var task types.MjTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + // if it's reference message, check if it's this channel's message + if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name { + s.taskQueue.RPush(task) + time.Sleep(time.Second) + continue + } + + logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) + var res ImageRes + switch task.Type { + case types.TaskImage: + index := strings.Index(task.Prompt, " ") + res, err = s.Client.Imagine(task.Prompt[index+1:]) + break + case types.TaskUpscale: + res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash) + break + case types.TaskVariation: + res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash) + } + + if err != nil || (res.Code != 1 && res.Code != 22) { + logger.Error("绘画任务执行失败:", err) + // update the task progress + s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) + // 任务失败,通知前端 + s.notifyQueue.RPush(task.UserId) + // restore img_call quota + s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) + + // TODO: 任务提交失败,加入队列重试 + continue + } + logger.Infof("任务提交成功:%+v", res) + // lock the task until the execute timeout + s.taskStartTimes[task.Id] = time.Now() + atomic.AddInt32(&s.handledTaskNum, 1) + // 更新任务 ID/频道 + s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{ + "task_id": res.Result, + "channel_id": s.Client.Config.Name, + }) + + } +} + +// check if current service instance can handle more task +func (s *Service) canHandleTask() bool { + handledNum := atomic.LoadInt32(&s.handledTaskNum) + return handledNum < s.maxHandleTaskNum +} + +// remove the expired tasks +func (s *Service) checkTasks() { + for k, t := range s.taskStartTimes { + if time.Now().Unix()-t.Unix() > s.taskTimeout { + delete(s.taskStartTimes, k) + atomic.AddInt32(&s.handledTaskNum, -1) + // delete task from database + s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") + } + } +} + +type CBReq struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Progress string `json:"progress"` + ImageUrl string `json:"imageUrl"` + FailReason interface{} `json:"failReason"` + Properties struct { + FinalPrompt string `json:"finalPrompt"` + } `json:"properties"` +} + +func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error { + + job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0) + job.Prompt = data.Properties.FinalPrompt + if data.ImageUrl != "" { + job.OrgURL = data.ImageUrl + } + job.UseProxy = true + job.MessageId = data.Id + logger.Debugf("JOB: %+v", job) + res := s.db.Updates(&job) + if res.Error != nil { + return fmt.Errorf("error with update job: %v", res.Error) + } + + if data.Status == "SUCCESS" { + // release lock task + atomic.AddInt32(&s.handledTaskNum, -1) + } + + s.notifyQueue.RPush(job.UserId) + return nil +} diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index b446ad23..b904640a 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -2,11 +2,13 @@ package mj import ( "chatplus/core/types" + "chatplus/service/mj/plus" "chatplus/service/oss" "chatplus/store" "chatplus/store/model" "fmt" "github.com/go-redis/redis/v8" + "strings" "time" "gorm.io/gorm" @@ -14,7 +16,7 @@ import ( // ServicePool Mj service pool type ServicePool struct { - services []*Service + services []interface{} taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB @@ -23,37 +25,53 @@ type ServicePool struct { } func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { - services := make([]*Service, 0) + services := make([]interface{}, 0) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) - // create mj client and service - for k, config := range appConfig.MjConfigs { + + for k, config := range appConfig.MjPlusConfigs { if config.Enabled == false { continue } - // create mj client - client := NewClient(config, appConfig.ProxyURL) - - name := fmt.Sprintf("MjService-%d", k) - // create mj service - service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client) - botName := fmt.Sprintf("MjBot-%d", k) - bot, err := NewBot(botName, appConfig.ProxyURL, config, service) - if err != nil { - continue - } - - err = bot.Run() - if err != nil { - continue - } - - // run mj service + client := plus.NewClient(config) + name := fmt.Sprintf("MidJourney Plus Service-%d", k) + servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) go func() { - service.Run() + servicePlus.Run() }() + services = append(services, servicePlus) + } - services = append(services, service) + if len(services) == 0 { + // create mj client and service + for k, config := range appConfig.MjConfigs { + if config.Enabled == false { + continue + } + // create mj client + client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL) + + name := fmt.Sprintf("MjService-%d", k) + // create mj service + service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client) + botName := fmt.Sprintf("MjBot-%d", k) + bot, err := NewBot(botName, appConfig.ProxyURL, config, service) + if err != nil { + continue + } + + err = bot.Run() + if err != nil { + continue + } + + // run mj service + go func() { + service.Run() + }() + + services = append(services, service) + } } return &ServicePool{ @@ -94,7 +112,24 @@ func (p *ServicePool) DownloadImages() { // download images for _, v := range items { - imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true) + if v.OrgURL == "" { + continue + } + var imgURL string + var err error + if v.UseProxy { + if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil { + task, _ := servicePlus.Client.QueryTask(v.TaskId) + if task.ImageUrl != "" { + imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false) + } + if len(task.Buttons) > 0 { + v.Hash = getImageHash(task.Buttons[0].CustomId) + } + } + } else { + imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true) + } if err != nil { logger.Error("error with download image: ", err) continue @@ -125,3 +160,37 @@ func (p *ServicePool) PushTask(task types.MjTask) { func (p *ServicePool) HasAvailableService() bool { return len(p.services) > 0 } + +func (p *ServicePool) Notify(data plus.CBReq) error { + logger.Infof("收到任务回调:%+v", data) + var job model.MidJourneyJob + res := p.db.Where("task_id = ?", data.Id).First(&job) + if res.Error != nil { + return fmt.Errorf("非法任务:%s", data.Id) + } + + if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { + return servicePlus.Notify(data, job) + } + + return nil +} + +func (p *ServicePool) getServicePlus(name string) *plus.Service { + for _, s := range p.services { + if servicePlus, ok := s.(*plus.Service); ok { + if servicePlus.Client.Config.Name == name { + return servicePlus + } + } + } + return nil +} + +func getImageHash(action string) string { + split := strings.Split(action, "::") + if len(split) > 5 { + return split[4] + } + return split[len(split)-1] +} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 1c03de6d..8ec8dcac 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -58,8 +58,6 @@ func (s *Service) Run() { // if it's reference message, check if it's this channel's message if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId { s.taskQueue.RPush(task) - s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) - s.notifyQueue.RPush(task.UserId) time.Sleep(time.Second) continue } @@ -143,7 +141,7 @@ func (s *Service) Notify(data CBReq) { job.OrgURL = data.Image.URL if s.client.Config.UseCDN { job.UseProxy = true - job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN) + job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL) } res = s.db.Updates(&job) diff --git a/api/test/test.go b/api/test/test.go index 79058077..008479b8 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,5 +1,12 @@ package main -func main() { +import ( + "fmt" + "strings" +) +func main() { + str := "7151109597841850368 一个漂亮的中国女孩,手上拿着一桶爆米花,脸上带着迷人的微笑,电影效果" + index := strings.Index(str, " ") + fmt.Println(str[index+1:]) } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index c021c17f..c81fce86 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -700,6 +700,7 @@ const generate = () => { // 图片放大任务 const upscale = (index, item) => { + console.log(item) send('/api/mj/upscale', index, item) }