From fd3a41bacb326f7d1b15b9b447efafced679d419 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 16:19:22 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AF=B7=E6=B1=82=E8=B6=85=E6=97=B6?=
=?UTF-8?q?=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
dto/midjourney.go | 1 +
relay/relay-mj.go | 27 +++++++++++++--------------
web/src/components/MjLogsTable.js | 2 ++
web/src/pages/Channel/EditChannel.js | 6 ++++++
4 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/dto/midjourney.go b/dto/midjourney.go
index a16a65e..4fef4e1 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -25,6 +25,7 @@ type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
+ BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index f667cd1..5fafc89 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -112,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
-func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -181,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
- midjourneyTask := getMidjourneyTaskDto(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -204,7 +204,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
- midjourneyTask := getMidjourneyTaskDto(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
@@ -403,23 +403,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-
+ timeout := time.Second * 30
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // 使用带有超时的 context 创建新的请求
+ req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- //mjToken := ""
- //if c.Request.Header.Get("ApiKey") != "" {
- // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
- //}
- //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
- log.Printf("request header: %s", req.Header)
- log.Printf("request body: %s", midjRequest.Prompt)
+ //log.Printf("request header: %s", req.Header)
+ //log.Printf("request body: %s", midjRequest.Prompt)
+ defer cancel()
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "do_request_failed",
}
}
@@ -427,14 +426,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
err = req.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "close_request_body_failed",
}
}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 4f17c14..4accf54 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,8 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'REROLL':
+ return 重绘;
case 'INPAINT':
return 局部重绘;
case 'INPAINT_PRE':
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 7221b9a..ee79368 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -95,6 +95,12 @@ const EditChannel = (props) => {
case 26:
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
break;
+ case 2:
+ localModels = ['midjourney'];
+ break;
+ case 5:
+ localModels = ['midjourney'];
+ break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
}