From fd3a41bacb326f7d1b15b9b447efafced679d419 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 16:19:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AF=B7=E6=B1=82=E8=B6=85=E6=97=B6?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/midjourney.go | 1 + relay/relay-mj.go | 27 +++++++++++++-------------- web/src/components/MjLogsTable.js | 2 ++ web/src/pages/Channel/EditChannel.js | 6 ++++++ 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/dto/midjourney.go b/dto/midjourney.go index a16a65e..4fef4e1 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -25,6 +25,7 @@ type MidjourneyDto struct { MjId string `json:"id"` Action string `json:"action"` CustomId string `json:"customId"` + BotType string `json:"botType"` Prompt string `json:"prompt"` PromptEn string `json:"promptEn"` Description string `json:"description"` diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f667cd1..5fafc89 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -112,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { return nil } -func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { +func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -181,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "task_no_found", } } - midjourneyTask := getMidjourneyTaskDto(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { return &dto.MidjourneyResponse{ @@ -204,7 +204,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse if len(condition.IDs) != 0 { originTasks := model.GetByMJIds(userId, condition.IDs) for _, originTask := range originTasks { - midjourneyTask := getMidjourneyTaskDto(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) tasks = append(tasks, midjourneyTask) } } @@ -403,23 +403,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) - + timeout := time.Second * 30 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - //mjToken := "" - //if c.Request.Header.Get("ApiKey") != "" { - // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1] - //} - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) // print request header - log.Printf("request header: %s", req.Header) - log.Printf("request body: %s", midjRequest.Prompt) + //log.Printf("request header: %s", req.Header) + //log.Printf("request body: %s", midjRequest.Prompt) + defer cancel() resp, err := service.GetHttpClient().Do(req) if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "do_request_failed", } } @@ -427,14 +426,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons err = req.Body.Close() if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "close_request_body_failed", } } err = c.Request.Body.Close() if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "close_request_body_failed", } } diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 4f17c14..4accf54 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,8 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'REROLL': + return 重绘; case 'INPAINT': return 局部重绘; case 'INPAINT_PRE': diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7221b9a..ee79368 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -95,6 +95,12 @@ const EditChannel = (props) => { case 26: localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; break; + case 2: + localModels = ['midjourney']; + break; + case 5: + localModels = ['midjourney']; + break; } setInputs((inputs) => ({...inputs, models: localModels})); }