feat: support image-seed (close #86)

This commit is contained in:
CaIon 2024-03-14 16:59:46 +08:00
parent d704902b70
commit bc5a54df59

View File

@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
} }
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
//taskId := c.Param("id") taskId := c.Param("id")
//userId := c.GetInt("id") userId := c.GetInt("id")
//originTask := model.GetByMJId(userId, taskId) originTask := model.GetByMJId(userId, taskId)
//if originTask == nil { if originTask == nil {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
//} }
//channel, err := model.GetChannelById(originTask.ChannelId, false) channel, err := model.GetChannelById(originTask.ChannelId, true)
//if err != nil { if err != nil {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
//} }
//if channel.Status != common.ChannelStatusEnabled { if channel.Status != common.ChannelStatusEnabled {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
//} }
//c.Set("channel_id", originTask.ChannelId) c.Set("channel_id", originTask.ChannelId)
//requestURL := c.Request.URL.String() c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
//fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) requestURL := c.Request.URL.String()
//if err != nil { fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed") midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
//} if err != nil {
log.Println("RelayMidjourneyTaskImageSeed") return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil return nil
} }
@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理 } else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false) channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil { if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
} }
@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} }
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId) c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
} }
midjRequest.Prompt = originTask.Prompt midjRequest.Prompt = originTask.Prompt