feat: add blend and swapface task implements for midjourney

This commit is contained in:
RockYang
2024-01-25 18:50:24 +08:00
parent d772bbebe6
commit 24906a6df1
15 changed files with 375 additions and 157 deletions

View File

@@ -2,6 +2,7 @@ package mj
import (
"chatplus/core/types"
"errors"
"fmt"
"time"
@@ -33,7 +34,7 @@ func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *C
return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
}
func (c *Client) Imagine(prompt string) error {
func (c *Client) Imagine(task types.MjTask) error {
interactionsReq := &InteractionsRequest{
Type: 2,
ApplicationID: ApplicationID,
@@ -49,7 +50,7 @@ func (c *Client) Imagine(prompt string) error {
{
"type": 3,
"name": "prompt",
"value": prompt,
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
},
},
"application_command": map[string]any{
@@ -88,20 +89,28 @@ func (c *Client) Imagine(prompt string) error {
return nil
}
func (c *Client) Blend(task types.MjTask) error {
return errors.New("function not implemented")
}
func (c *Client) SwapFace(task types.MjTask) error {
return errors.New("function not implemented")
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) error {
func (c *Client) Upscale(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}
@@ -120,19 +129,19 @@ func (c *Client) Upscale(index int, messageId string, hash string) error {
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) error {
func (c *Client) Variation(task types.MjTask) error {
flags := 0
interactionsReq := &InteractionsRequest{
Type: 3,
ApplicationID: ApplicationID,
GuildID: c.Config.GuildId,
ChannelID: c.Config.ChanelId,
MessageFlags: &flags,
MessageID: &messageId,
MessageFlags: flags,
MessageID: task.MessageId,
SessionID: SessionID,
Data: map[string]any{
"component_type": 2,
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
},
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
}

View File

@@ -3,8 +3,11 @@ package plus
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils"
"encoding/base64"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"github.com/imroc/req/v3"
@@ -22,9 +25,10 @@ func NewClient(config types.MidJourneyPlusConfig) *Client {
}
type ImageReq struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
BotType string `json:"botType"`
Prompt string `json:"prompt,omitempty"`
Dimensions string `json:"dimensions,omitempty"`
Base64Array []string `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
@@ -49,12 +53,114 @@ type ErrRes struct {
} `json:"error"`
}
func (c *Client) Imagine(prompt string) (ImageRes, error) {
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
NotifyHook: c.Config.NotifyURL,
BotType: "MID_JOURNEY",
Prompt: task.Prompt,
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 1),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/blend", c.Config.ApiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 1),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("请求 API 出错:%v%v", err, string(errStr))
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-fast/mj/insight-face/swap", c.Config.ApiURL)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"notifyHook": c.Config.NotifyURL,
"state": "",
}
var res ImageRes
var errRes ErrRes
@@ -77,10 +183,10 @@ func (c *Client) Imagine(prompt string) (ImageRes, error) {
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) {
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
"taskId": messageId,
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
@@ -104,10 +210,10 @@ func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, er
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) {
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
"taskId": messageId,
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)

View File

@@ -69,18 +69,24 @@ func (s *Service) Run() {
var res ImageRes
switch task.Type {
case types.TaskImage:
index := strings.Index(task.Prompt, " ")
res, err = s.Client.Imagine(task.Prompt[index+1:])
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash)
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash)
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
logger.Error("绘画任务执行失败:", err)
logger.Error("绘画任务执行失败:", err, res.Description)
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// 任务失败,通知前端

View File

@@ -65,14 +65,20 @@ func (s *Service) Run() {
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
err = s.client.Imagine(task)
break
case types.TaskUpscale:
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
err = s.client.Upscale(task)
break
case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
err = s.client.Variation(task)
break
case types.TaskBlend:
err = s.client.Blend(task)
break
case types.TaskSwapFace:
err = s.client.SwapFace(task)
break
}
if err != nil {

View File

@@ -8,8 +8,8 @@ const (
type InteractionsRequest struct {
Type int `json:"type"`
ApplicationID string `json:"application_id"`
MessageFlags *int `json:"message_flags,omitempty"`
MessageID *string `json:"message_id,omitempty"`
MessageFlags int `json:"message_flags,omitempty"`
MessageID string `json:"message_id,omitempty"`
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
SessionID string `json:"session_id"`