diff --git a/Midjourney.md b/Midjourney.md index d495e84..5733a11 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -19,8 +19,9 @@ - mj_zoom (比例变焦) - mj_shorten (提示词缩短) -- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加) -- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加) +- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加) +- mj_inpaint (局部重绘提交,必须和mj_modal一同添加) +- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加) - mj_high_variation (强变换) - mj_low_variation (弱变换) - mj_pan (平移) @@ -32,13 +33,14 @@ "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, - "mj_inpaint": 0.1, + "mj_modal": 0.1, "mj_zoom": 0.1, "mj_shorten": 0.1, "mj_high_variation": 0.1, "mj_low_variation": 0.1, "mj_pan": 0.1, - "mj_inpaint_pre": 0, + "mj_inpaint": 0, + "mj_custom_zoom": 0, "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05 diff --git a/constant/midjourney.go b/constant/midjourney.go index f4ae2e4..3d321ca 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -20,7 +20,7 @@ const ( MjActionHighVariation = "HIGH_VARIATION" MjActionLowVariation = "LOW_VARIATION" MjActionPan = "PAN" - SwapFace = "SWAP_FACE" + MjActionSwapFace = "SWAP_FACE" ) var MidjourneyModel2Action = map[string]string{ @@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{ "mj_high_variation": MjActionHighVariation, "mj_low_variation": MjActionLowVariation, "mj_pan": MjActionPan, - "swap_face": SwapFace, + "swap_face": MjActionSwapFace, } diff --git a/controller/relay.go b/controller/relay.go index e31679d..9f866b8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyTask(c, relayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: err = relay.RelayMidjourneyTaskImageSeed(c) + case relayconstant.RelayModeSwapFace: + err = relay.RelaySwapFace(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } diff --git a/dto/midjourney.go b/dto/midjourney.go index d3c3583..c675f7e 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -7,6 +7,11 @@ package dto // Content string `json:"content"` //} +type SwapFaceRequest struct { + SourceBase64 string `json:"sourceBase64"` + TargetBase64 string `json:"targetBase64"` +} + type MidjourneyRequest struct { Prompt string `json:"prompt"` CustomId string `json:"customId"` diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 197efdc..1790c57 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -25,6 +25,7 @@ const ( RelayModeMidjourneyAction RelayModeMidjourneyModal RelayModeMidjourneyShorten + RelayModeSwapFace ) func Path2RelayMode(path string) int { @@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int { } else if strings.HasPrefix(path, "/mj/submit/shorten") { // midjourney plus relayMode = RelayModeMidjourneyShorten + } else if strings.HasPrefix(path, "/mj/insight-face/swap") { + // midjourney plus + relayMode = RelayModeSwapFace } else if strings.HasPrefix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine } else if strings.HasPrefix(path, "/mj/submit/blend") { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 6185d5b..01ae0c8 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } +func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { + startTime := time.Now().UnixNano() / int64(time.Millisecond) + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + channelId := c.GetInt("channel_id") + var swapFaceRequest dto.SwapFaceRequest + err := common.UnmarshalBodyReusable(c, &swapFaceRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") + } + modelName := service.CoverActionToModelName(constant.MjActionSwapFace) + modelPrice := common.GetModelPrice(modelName, true) + // 如果没有配置价格,则使用默认价格 + if modelPrice == -1 { + defaultPrice, ok := common.DefaultModelPrice[modelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + groupRatio := common.GetGroupRatio(group) + ratio := modelPrice * groupRatio + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return &dto.MidjourneyResponse{ + Code: 4, + Description: err.Error(), + } + } + quota := int(ratio * common.QuotaPerUnit) + + if userQuota-quota < 0 { + return &dto.MidjourneyResponse{ + Code: 4, + Description: "quota_not_enough", + } + } + requestURL := c.Request.URL.String() + baseURL := c.GetString("base_url") + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL) + if err != nil { + return &mjResp.Response + } + defer func(ctx context.Context) { + if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }(c.Request.Context()) + midjResponse := &mjResp.Response + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: constant.MjActionSwapFace, + MjId: midjResponse.Result, + Prompt: "swap_face", + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: startTime, + StartTime: time.Now().UnixNano() / int64(time.Millisecond), + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + ChannelId: c.GetInt("channel_id"), + Quota: quota, + } + err = midjourneyTask.Insert() + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed") + } + c.Writer.WriteHeader(mjResp.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil +} + func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { taskId := c.Param("id") userId := c.GetInt("id") @@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { requestURL := c.Request.URL.String() fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) - midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil) + midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) if err != nil { return &midjResponseWithStatus.Response } + //defer func(ctx context.Context) { + // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + // if err != nil { + // common.SysError("error consuming token remain quota: " + err.Error()) + // } + // err = model.CacheUpdateUserQuota(userId) + // if err != nil { + // common.SysError("error update user quota cache: " + err.Error()) + // } + // if quota != 0 { + // tokenName := c.GetString("token_name") + // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) + // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + // model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + // channelId := c.GetInt("channel_id") + // model.UpdateChannelUsedQuota(channelId, quota) + // } + //}(c.Request.Context()) midjResponse := &midjResponseWithStatus.Response c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) respBody, err := json.Marshal(midjResponse) @@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) if err != nil { return &midjResponseWithStatus.Response } midjResponse := &midjResponseWithStatus.Response defer func(ctx context.Context) { - if consumeQuota { + if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) diff --git a/router/relay-router.go b/router/relay-router.go index 3c6910a..4addee0 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) + relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) } //relayMjRouter.Use() } diff --git a/service/midjourney.go b/service/midjourney.go index 17e54e1..7c47cd6 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -17,6 +17,9 @@ import ( func CoverActionToModelName(mjAction string) string { modelName := "mj_" + strings.ToLower(mjAction) + if mjAction == constant.MjActionSwapFace { + modelName = "swap_face" + } return modelName } @@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin action = midjRequest.Action case relayconstant.RelayModeMidjourneyModal: action = constant.MjActionModal + case relayconstant.RelayModeSwapFace: + action = constant.MjActionSwapFace case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { return changeParams } -func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { +func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { var nullBytes []byte - var requestBody io.Reader - requestBody = c.Request.Body - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + //var requestBody io.Reader + //requestBody = c.Request.Body + // read request body to json, delete accountFilter and notifyHook + var mapResult map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + delete(mapResult, "accountFilter") + delete(mapResult, "notifyHook") + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + // make new request with mapResult + reqBody, err := json.Marshal(mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody))) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err } diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 4843b2f..90c55f1 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -53,6 +53,8 @@ function renderType(type) { return 自定义变焦-提交; case 'MODAL': return 窗口处理; + case 'SWAP_FACE': + return 换脸; default: return 未知; } @@ -67,6 +69,8 @@ function renderCode(code) { return 等待中; case 22: return 重复提交; + case 0: + return 未提交; default: return 未知; }