mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-19 00:46:37 +08:00
feat: support image-seed (close #86)
This commit is contained in:
parent
d704902b70
commit
bc5a54df59
@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
//taskId := c.Param("id")
|
taskId := c.Param("id")
|
||||||
//userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
//originTask := model.GetByMJId(userId, taskId)
|
originTask := model.GetByMJId(userId, taskId)
|
||||||
//if originTask == nil {
|
if originTask == nil {
|
||||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
||||||
//}
|
}
|
||||||
//channel, err := model.GetChannelById(originTask.ChannelId, false)
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||||
//if err != nil {
|
if err != nil {
|
||||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||||
//}
|
}
|
||||||
//if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||||
//}
|
}
|
||||||
//c.Set("channel_id", originTask.ChannelId)
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
//requestURL := c.Request.URL.String()
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
//fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
|
||||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
|
requestURL := c.Request.URL.String()
|
||||||
//if err != nil {
|
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
|
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
|
||||||
//}
|
if err != nil {
|
||||||
log.Println("RelayMidjourneyTaskImageSeed")
|
return &midjResponseWithStatus.Response
|
||||||
|
}
|
||||||
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||||
}
|
}
|
||||||
@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
c.Set("channel_id", originTask.ChannelId)
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||||
}
|
}
|
||||||
midjRequest.Prompt = originTask.Prompt
|
midjRequest.Prompt = originTask.Prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user