From 6ab1b3a524703c95a787471116f2b98a49d67dc5 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 22:17:03 +0800 Subject: [PATCH] fix: fix image-seed error --- relay/relay-mj.go | 18 ------------------ service/midjourney.go | 35 +++++++++++++++++++++-------------- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 35353b4..3cd42cb 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -266,24 +266,6 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return &midjResponseWithStatus.Response } - //defer func(ctx context.Context) { - // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - // if err != nil { - // common.SysError("error consuming token remain quota: " + err.Error()) - // } - // err = model.CacheUpdateUserQuota(userId) - // if err != nil { - // common.SysError("error update user quota cache: " + err.Error()) - // } - // if quota != 0 { - // tokenName := c.GetString("token_name") - // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) - // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) - // model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - // channelId := c.GetInt("channel_id") - // model.UpdateChannelUsedQuota(channelId, quota) - // } - //}(c.Request.Context()) midjResponse := &midjResponseWithStatus.Response c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) respBody, err := json.Marshal(midjResponse) diff --git a/service/midjourney.go b/service/midjourney.go index 3ab967f..698e100 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -158,16 +158,19 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU //requestBody = c.Request.Body // read request body to json, delete accountFilter and notifyHook var mapResult map[string]interface{} - err := json.NewDecoder(c.Request.Body).Decode(&mapResult) - if err != nil { - return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + // if get request, no need to read request body + if c.Request.Method != "GET" { + err := json.NewDecoder(c.Request.Body).Decode(&mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + delete(mapResult, "accountFilter") + if !constant.MjNotifyEnabled { + delete(mapResult, "notifyHook") + } + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + // make new request with mapResult } - delete(mapResult, "accountFilter") - if !constant.MjNotifyEnabled { - delete(mapResult, "notifyHook") - } - //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - // make new request with mapResult reqBody, err := json.Marshal(mapResult) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err @@ -209,11 +212,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err } - - err = json.Unmarshal(responseBody, &midjResponse) - log.Printf("responseBody: %s", string(responseBody)) - if err != nil { - return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + respStr := string(responseBody) + log.Printf("responseBody: %s", respStr) + if respStr == "" { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil + } else { + err = json.Unmarshal(responseBody, &midjResponse) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } } //log.Printf("midjResponse: %v", midjResponse) //for k, v := range resp.Header {