diff --git a/Midjourney.md b/Midjourney.md
index d495e84..5733a11 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -19,8 +19,9 @@
- mj_zoom (比例变焦)
- mj_shorten (提示词缩短)
-- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加)
-- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
- mj_high_variation (强变换)
- mj_low_variation (弱变换)
- mj_pan (平移)
@@ -32,13 +33,14 @@
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
- "mj_inpaint": 0.1,
+ "mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
- "mj_inpaint_pre": 0,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05
diff --git a/constant/midjourney.go b/constant/midjourney.go
index f4ae2e4..3d321ca 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -20,7 +20,7 @@ const (
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
- SwapFace = "SWAP_FACE"
+ MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
@@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
- "swap_face": SwapFace,
+ "swap_face": MjActionSwapFace,
}
diff --git a/controller/relay.go b/controller/relay.go
index e31679d..9f866b8 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
+ case relayconstant.RelayModeSwapFace:
+ err = relay.RelaySwapFace(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index d3c3583..c675f7e 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -7,6 +7,11 @@ package dto
// Content string `json:"content"`
//}
+type SwapFaceRequest struct {
+ SourceBase64 string `json:"sourceBase64"`
+ TargetBase64 string `json:"targetBase64"`
+}
+
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 197efdc..1790c57 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -25,6 +25,7 @@ const (
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
+ RelayModeSwapFace
)
func Path2RelayMode(path string) int {
@@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = RelayModeMidjourneyShorten
+ } else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+ // midjourney plus
+ relayMode = RelayModeSwapFace
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(path, "/mj/submit/blend") {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 6185d5b..01ae0c8 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
+ tokenId := c.GetInt("token_id")
+ userId := c.GetInt("id")
+ group := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ var swapFaceRequest dto.SwapFaceRequest
+ err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+ }
+ modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+ modelPrice := common.GetModelPrice(modelName, true)
+ // 如果没有配置价格,则使用默认价格
+ if modelPrice == -1 {
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
+ if !ok {
+ modelPrice = 0.1
+ } else {
+ modelPrice = defaultPrice
+ }
+ }
+ groupRatio := common.GetGroupRatio(group)
+ ratio := modelPrice * groupRatio
+ userQuota, err := model.CacheGetUserQuota(userId)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: err.Error(),
+ }
+ }
+ quota := int(ratio * common.QuotaPerUnit)
+
+ if userQuota-quota < 0 {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+ requestURL := c.Request.URL.String()
+ baseURL := c.GetString("base_url")
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
+ if err != nil {
+ return &mjResp.Response
+ }
+ defer func(ctx context.Context) {
+ if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+ }
+ }(c.Request.Context())
+ midjResponse := &mjResp.Response
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: constant.MjActionSwapFace,
+ MjId: midjResponse.Result,
+ Prompt: "swap_face",
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: startTime,
+ StartTime: time.Now().UnixNano() / int64(time.Millisecond),
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ ChannelId: c.GetInt("channel_id"),
+ Quota: quota,
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+ }
+ c.Writer.WriteHeader(mjResp.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+}
+
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
taskId := c.Param("id")
userId := c.GetInt("id")
@@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
requestURL := c.Request.URL.String()
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
- midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
+ //defer func(ctx context.Context) {
+ // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ // if err != nil {
+ // common.SysError("error consuming token remain quota: " + err.Error())
+ // }
+ // err = model.CacheUpdateUserQuota(userId)
+ // if err != nil {
+ // common.SysError("error update user quota cache: " + err.Error())
+ // }
+ // if quota != 0 {
+ // tokenName := c.GetString("token_name")
+ // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+ // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ // model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ // channelId := c.GetInt("channel_id")
+ // model.UpdateChannelUsedQuota(channelId, quota)
+ // }
+ //}(c.Request.Context())
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
@@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
- if consumeQuota {
+ if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
diff --git a/router/relay-router.go b/router/relay-router.go
index 3c6910a..4addee0 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+ relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
}
//relayMjRouter.Use()
}
diff --git a/service/midjourney.go b/service/midjourney.go
index 17e54e1..7c47cd6 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -17,6 +17,9 @@ import (
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
+ if mjAction == constant.MjActionSwapFace {
+ modelName = "swap_face"
+ }
return modelName
}
@@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionModal
+ case relayconstant.RelayModeSwapFace:
+ action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
return changeParams
}
-func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
- var requestBody io.Reader
- requestBody = c.Request.Body
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ //var requestBody io.Reader
+ //requestBody = c.Request.Body
+ // read request body to json, delete accountFilter and notifyHook
+ var mapResult map[string]interface{}
+ err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ delete(mapResult, "accountFilter")
+ delete(mapResult, "notifyHook")
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ // make new request with mapResult
+ reqBody, err := json.Marshal(mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 4843b2f..90c55f1 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -53,6 +53,8 @@ function renderType(type) {
return 自定义变焦-提交;
case 'MODAL':
return 窗口处理;
+ case 'SWAP_FACE':
+ return 换脸;
default:
return 未知;
}
@@ -67,6 +69,8 @@ function renderCode(code) {
return 等待中;
case 22:
return 重复提交;
+ case 0:
+ return 未提交;
default:
return 未知;
}