feat: support image-seed (close #86)

This commit is contained in:
CaIon 2024-03-14 16:59:46 +08:00
parent d704902b70
commit bc5a54df59

View File

@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
//taskId := c.Param("id")
//userId := c.GetInt("id")
//originTask := model.GetByMJId(userId, taskId)
//if originTask == nil {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
//}
//channel, err := model.GetChannelById(originTask.ChannelId, false)
//if err != nil {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
//}
//if channel.Status != common.ChannelStatusEnabled {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
//}
//c.Set("channel_id", originTask.ChannelId)
//requestURL := c.Request.URL.String()
//fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
//if err != nil {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
//}
log.Println("RelayMidjourneyTaskImageSeed")
taskId := c.Param("id")
userId := c.GetInt("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
}
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
if channel.Status != common.ChannelStatusEnabled {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
}
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
requestURL := c.Request.URL.String()
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil
}
@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt