diff --git a/api/core/types/task.go b/api/core/types/task.go index f8a3b48b..cb22c395 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -9,6 +9,8 @@ func (t TaskType) String() string { const ( TaskImage = TaskType("image") + TaskBlend = TaskType("blend") + TaskSwapFace = TaskType("swapFace") TaskUpscale = TaskType("upscale") TaskVariation = TaskType("variation") ) @@ -16,6 +18,8 @@ const ( // MjTask MidJourney 任务 type MjTask struct { Id int `json:"id"` + TaskId string `json:"task_id"` + ImgArr []string `json:"img_arr"` ChannelId string `json:"channel_id"` SessionId string `json:"session_id"` Type TaskType `json:"type"` diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index ea22f9c1..89995d98 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -86,19 +86,20 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { - SessionId string `json:"session_id"` - Prompt string `json:"prompt"` - NegPrompt string `json:"neg_prompt"` - Rate string `json:"rate"` - Model string `json:"model"` - Chaos int `json:"chaos"` - Raw bool `json:"raw"` - Seed int64 `json:"seed"` - Stylize int `json:"stylize"` - Img string `json:"img"` - Tile bool `json:"tile"` - Quality float32 `json:"quality"` - Weight float32 `json:"weight"` + SessionId string `json:"session_id"` + TaskType string `json:"task_type"` + Prompt string `json:"prompt"` + NegPrompt string `json:"neg_prompt"` + Rate string `json:"rate"` + Model string `json:"model"` + Chaos int `json:"chaos"` + Raw bool `json:"raw"` + Seed int64 `json:"seed"` + Stylize int `json:"stylize"` + ImgArr []string `json:"img_arr"` + Tile bool `json:"tile"` + Quality float32 `json:"quality"` + Weight float32 `json:"weight"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -121,11 +122,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { prompt += fmt.Sprintf(" --c %d", data.Chaos) } - if data.Img != "" { - prompt = fmt.Sprintf("%s %s", data.Img, prompt) - if data.Weight > 0 { - prompt += fmt.Sprintf(" --iw %f", data.Weight) - } + if data.Weight > 0 { + prompt += fmt.Sprintf(" --iw %f", data.Weight) } if data.Raw { prompt += " --style raw" @@ -152,7 +150,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } job := model.MidJourneyJob{ - Type: types.TaskImage.String(), + Type: data.TaskType, UserId: userId, TaskId: taskId, Progress: 0, @@ -166,10 +164,12 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { h.pool.PushTask(types.MjTask{ Id: int(job.Id), + TaskId: taskId, SessionId: data.SessionId, Type: types.TaskImage, - Prompt: fmt.Sprintf("%s %s", taskId, prompt), + Prompt: prompt, UserId: userId, + ImgArr: data.ImgArr, }) client := h.pool.Clients.Get(uint(job.UserId)) diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 1540f285..bd557628 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -2,6 +2,7 @@ package mj import ( "chatplus/core/types" + "errors" "fmt" "time" @@ -33,7 +34,7 @@ func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *C return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL} } -func (c *Client) Imagine(prompt string) error { +func (c *Client) Imagine(task types.MjTask) error { interactionsReq := &InteractionsRequest{ Type: 2, ApplicationID: ApplicationID, @@ -49,7 +50,7 @@ func (c *Client) Imagine(prompt string) error { { "type": 3, "name": "prompt", - "value": prompt, + "value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt), }, }, "application_command": map[string]any{ @@ -88,20 +89,28 @@ func (c *Client) Imagine(prompt string) error { return nil } +func (c *Client) Blend(task types.MjTask) error { + return errors.New("function not implemented") +} + +func (c *Client) SwapFace(task types.MjTask) error { + return errors.New("function not implemented") +} + // Upscale 放大指定的图片 -func (c *Client) Upscale(index int, messageId string, hash string) error { +func (c *Client) Upscale(task types.MjTask) error { flags := 0 interactionsReq := &InteractionsRequest{ Type: 3, ApplicationID: ApplicationID, GuildID: c.Config.GuildId, ChannelID: c.Config.ChanelId, - MessageFlags: &flags, - MessageID: &messageId, + MessageFlags: flags, + MessageID: task.MessageId, SessionID: SessionID, Data: map[string]any{ "component_type": 2, - "custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash), + "custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), }, Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), } @@ -120,19 +129,19 @@ func (c *Client) Upscale(index int, messageId string, hash string) error { } // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *Client) Variation(index int, messageId string, hash string) error { +func (c *Client) Variation(task types.MjTask) error { flags := 0 interactionsReq := &InteractionsRequest{ Type: 3, ApplicationID: ApplicationID, GuildID: c.Config.GuildId, ChannelID: c.Config.ChanelId, - MessageFlags: &flags, - MessageID: &messageId, + MessageFlags: flags, + MessageID: task.MessageId, SessionID: SessionID, Data: map[string]any{ "component_type": 2, - "custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash), + "custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), }, Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), } diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go index 5511898a..45ba1481 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus/client.go @@ -3,8 +3,11 @@ package plus import ( "chatplus/core/types" logger2 "chatplus/logger" + "chatplus/utils" + "encoding/base64" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "github.com/imroc/req/v3" @@ -22,9 +25,10 @@ func NewClient(config types.MidJourneyPlusConfig) *Client { } type ImageReq struct { - BotType string `json:"botType"` - Prompt string `json:"prompt"` - Base64Array []interface{} `json:"base64Array,omitempty"` + BotType string `json:"botType"` + Prompt string `json:"prompt,omitempty"` + Dimensions string `json:"dimensions,omitempty"` + Base64Array []string `json:"base64Array,omitempty"` AccountFilter struct { InstanceId string `json:"instanceId"` Modes []interface{} `json:"modes"` @@ -49,12 +53,114 @@ type ErrRes struct { } `json:"error"` } -func (c *Client) Imagine(prompt string) (ImageRes, error) { +func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL) body := ImageReq{ - BotType: "MID_JOURNEY", - Prompt: prompt, - NotifyHook: c.Config.NotifyURL, + BotType: "MID_JOURNEY", + Prompt: task.Prompt, + NotifyHook: c.Config.NotifyURL, + Base64Array: make([]string, 1), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + + } + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Blend 融图 +func (c *Client) Blend(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/blend", c.Config.ApiURL) + body := ImageReq{ + BotType: "MID_JOURNEY", + Dimensions: "SQUARE", + NotifyHook: c.Config.NotifyURL, + Base64Array: make([]string, 1), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + for _, imgURL := range task.ImgArr { + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + } + } + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr)) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// SwapFace 换脸 +func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj-fast/mj/insight-face/swap", c.Config.ApiURL) + // 生成图片 Base64 编码 + if len(task.ImgArr) != 2 { + return ImageRes{}, errors.New("参数错误,必须上传2张图片") + } + var sourceBase64 string + var targetBase64 string + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + imageData, err = utils.DownloadImage(task.ImgArr[1], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + + body := gin.H{ + "sourceBase64": sourceBase64, + "targetBase64": targetBase64, + "accountFilter": gin.H{ + "instanceId": "", + }, + "notifyHook": c.Config.NotifyURL, + "state": "", } var res ImageRes var errRes ErrRes @@ -77,10 +183,10 @@ func (c *Client) Imagine(prompt string) (ImageRes, error) { } // Upscale 放大指定的图片 -func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) { +func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash), - "taskId": messageId, + "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, "notifyHook": c.Config.NotifyURL, } apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL) @@ -104,10 +210,10 @@ func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, er } // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) { +func (c *Client) Variation(task types.MjTask) (ImageRes, error) { body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash), - "taskId": messageId, + "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, "notifyHook": c.Config.NotifyURL, } apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL) diff --git a/api/service/mj/plus/service.go b/api/service/mj/plus/service.go index f77f6271..3b0d03c4 100644 --- a/api/service/mj/plus/service.go +++ b/api/service/mj/plus/service.go @@ -69,18 +69,24 @@ func (s *Service) Run() { var res ImageRes switch task.Type { case types.TaskImage: - index := strings.Index(task.Prompt, " ") - res, err = s.Client.Imagine(task.Prompt[index+1:]) + res, err = s.Client.Imagine(task) break case types.TaskUpscale: - res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash) + res, err = s.Client.Upscale(task) break case types.TaskVariation: - res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash) + res, err = s.Client.Variation(task) + break + case types.TaskBlend: + res, err = s.Client.Blend(task) + break + case types.TaskSwapFace: + res, err = s.Client.SwapFace(task) + break } if err != nil || (res.Code != 1 && res.Code != 22) { - logger.Error("绘画任务执行失败:", err) + logger.Error("绘画任务执行失败:", err, res.Description) // update the task progress s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1) // 任务失败,通知前端 diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 0b332807..23b8a04d 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -65,14 +65,20 @@ func (s *Service) Run() { logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) switch task.Type { case types.TaskImage: - err = s.client.Imagine(task.Prompt) + err = s.client.Imagine(task) break case types.TaskUpscale: - err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash) - + err = s.client.Upscale(task) break case types.TaskVariation: - err = s.client.Variation(task.Index, task.MessageId, task.MessageHash) + err = s.client.Variation(task) + break + case types.TaskBlend: + err = s.client.Blend(task) + break + case types.TaskSwapFace: + err = s.client.SwapFace(task) + break } if err != nil { diff --git a/api/service/mj/types.go b/api/service/mj/types.go index ec367210..ff6a5dd3 100644 --- a/api/service/mj/types.go +++ b/api/service/mj/types.go @@ -8,8 +8,8 @@ const ( type InteractionsRequest struct { Type int `json:"type"` ApplicationID string `json:"application_id"` - MessageFlags *int `json:"message_flags,omitempty"` - MessageID *string `json:"message_id,omitempty"` + MessageFlags int `json:"message_flags,omitempty"` + MessageID string `json:"message_id,omitempty"` GuildID string `json:"guild_id"` ChannelID string `json:"channel_id"` SessionID string `json:"session_id"` diff --git a/api/utils/net.go b/api/utils/net.go index 1746c2ea..58319dc6 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -48,12 +48,12 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) { }, } } - req, err := http.NewRequest("GET", imageURL, nil) + request, err := http.NewRequest("GET", imageURL, nil) if err != nil { return nil, err } - resp, err := client.Do(req) + resp, err := client.Do(request) if err != nil { return nil, err } diff --git a/web/src/assets/css/image-mj.css b/web/src/assets/css/image-mj.css index 44f552bc..cdb9197e 100644 --- a/web/src/assets/css/image-mj.css +++ b/web/src/assets/css/image-mj.css @@ -201,9 +201,6 @@ .page-mj .inner .task-list-box .task-list-inner .title-tabs .el-tabs__active-bar { background-color: #47fff1; } -.page-mj .inner .task-list-box .task-list-inner .title-tabs .el-tabs__content { - padding: 10px 0; -} .page-mj .inner .task-list-box .task-list-inner .el-textarea { --el-input-focus-border-color: #47fff1; } @@ -254,6 +251,12 @@ height: 120px; text-align: center; } +.page-mj .inner .task-list-box .task-list-inner .img-inline { + display: flex; +} +.page-mj .inner .task-list-box .task-list-inner .img-inline .img-uploader { + margin-right: 10px; +} .page-mj .inner .task-list-box .task-list-inner .submit-btn { display: flex; margin: 20px 0; diff --git a/web/src/assets/css/image-sd.css b/web/src/assets/css/image-sd.css index edb58873..15251961 100644 --- a/web/src/assets/css/image-sd.css +++ b/web/src/assets/css/image-sd.css @@ -86,9 +86,6 @@ .page-sd .inner .task-list-box .task-list-inner .title-tabs .el-tabs__active-bar { background-color: #47fff1; } -.page-sd .inner .task-list-box .task-list-inner .title-tabs .el-tabs__content { - padding: 10px 0; -} .page-sd .inner .task-list-box .task-list-inner .el-textarea { --el-input-focus-border-color: #47fff1; } @@ -139,6 +136,12 @@ height: 120px; text-align: center; } +.page-sd .inner .task-list-box .task-list-inner .img-inline { + display: flex; +} +.page-sd .inner .task-list-box .task-list-inner .img-inline .img-uploader { + margin-right: 10px; +} .page-sd .inner .task-list-box .task-list-inner .submit-btn { display: flex; margin: 20px 0; diff --git a/web/src/assets/css/mobile/image-mj.css b/web/src/assets/css/mobile/image-mj.css new file mode 100644 index 00000000..162e688a --- /dev/null +++ b/web/src/assets/css/mobile/image-mj.css @@ -0,0 +1,4 @@ +.mobile-mj .content .van-field__label { + width: 100px; + text-align: right; +} diff --git a/web/src/assets/css/mobile/image-mj.styl b/web/src/assets/css/mobile/image-mj.styl new file mode 100644 index 00000000..8be71238 --- /dev/null +++ b/web/src/assets/css/mobile/image-mj.styl @@ -0,0 +1,8 @@ +.mobile-mj { + .content { + .van-field__label { + width 100px + text-align right + } + } +} \ No newline at end of file diff --git a/web/src/assets/css/task-list.styl b/web/src/assets/css/task-list.styl index 61d1a111..9717c90b 100644 --- a/web/src/assets/css/task-list.styl +++ b/web/src/assets/css/task-list.styl @@ -23,10 +23,6 @@ background-color: #47FFF1; } - .title-tabs .el-tabs__content { - padding: 10px 0; - } - .el-textarea { --el-input-focus-border-color: #47FFF1; } @@ -90,6 +86,14 @@ } } + .img-inline { + display flex + + .img-uploader { + margin-right 10px + } + } + .submit-btn { display flex margin: 20px 0 @@ -192,7 +196,7 @@ top 10px } - &:hover{ + &:hover { .remove { display block } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index a1cd80ff..b8209755 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -167,8 +167,8 @@