feat: blend and swap face function for midjourney-plus is ready

This commit is contained in:
RockYang
2024-01-26 11:57:08 +08:00
parent dea72738c1
commit a0f3bc8ccb
7 changed files with 568 additions and 504 deletions

View File

@@ -157,6 +157,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Prompt: prompt,
CreatedAt: time.Now(),
}
if data.TaskType == types.TaskBlend.String() {
data.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
} else if data.TaskType == types.TaskSwapFace.String() {
data.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
@@ -166,7 +171,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Id: int(job.Id),
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskImage,
Type: types.TaskType(data.TaskType),
Prompt: prompt,
UserId: userId,
ImgArr: data.ImgArr,

View File

@@ -98,7 +98,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
NotifyHook: c.Config.NotifyURL,
Base64Array: make([]string, 1),
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
@@ -107,7 +107,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}

View File

@@ -167,7 +167,7 @@ func (p *ServicePool) HasAvailableService() bool {
}
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Infof("收到任务回调:%+v", data)
logger.Debugf("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
@@ -190,7 +190,7 @@ func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
res := p.db.Where("progress >= ? AND progress < ?", 0, 100).Find(&items)
if res.Error != nil {
continue
}
@@ -215,6 +215,11 @@ func (p *ServicePool) SyncTaskProgress() {
if err != nil {
continue
}
// 任务失败了
if task.FailReason != "" {
p.db.Model(&model.MidJourneyJob{Id: v.Id}).UpdateColumn("progress", -1)
continue
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
}