From bc5a54df59ea464186390b95aa56cfbd5c605d55 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 16:59:46 +0800 Subject: [PATCH] feat: support image-seed (close #86) --- relay/relay-mj.go | 56 ++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 8eebaeb..6185d5b 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo } func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { - //taskId := c.Param("id") - //userId := c.GetInt("id") - //originTask := model.GetByMJId(userId, taskId) - //if originTask == nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") - //} - //channel, err := model.GetChannelById(originTask.ChannelId, false) - //if err != nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") - //} - //if channel.Status != common.ChannelStatusEnabled { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") - //} - //c.Set("channel_id", originTask.ChannelId) - //requestURL := c.Request.URL.String() - //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) - //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) - //if err != nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed") - //} - log.Println("RelayMidjourneyTaskImageSeed") + taskId := c.Param("id") + userId := c.GetInt("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + } + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + } + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + requestURL := c.Request.URL.String() + fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) + midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil) + if err != nil { + return &midjResponseWithStatus.Response + } + midjResponse := &midjResponseWithStatus.Response + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } return nil } @@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 - channel, err := model.GetChannelById(originTask.ChannelId, false) + channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } @@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt