feat: support InsightFace (close #60)

This commit is contained in:
CaIon 2024-03-14 18:08:12 +08:00
parent bc5a54df59
commit 9b5353a81a
9 changed files with 173 additions and 13 deletions

View File

@ -19,8 +19,9 @@
- mj_zoom (比例变焦) - mj_zoom (比例变焦)
- mj_shorten (提示词缩短) - mj_shorten (提示词缩短)
- mj_inpaint_pre (发起局部重绘必须和mj_inpaint一同添加) - mj_modal (窗口提交局部重绘和自定义比例变焦必须和mj_modal一同添加)
- mj_inpaint (局部重绘提交必须和mj_inpaint_pre一同添加) - mj_inpaint (局部重绘提交必须和mj_modal一同添加)
- mj_custom_zoom (自定义比例变焦必须和mj_modal一同添加)
- mj_high_variation (强变换) - mj_high_variation (强变换)
- mj_low_variation (弱变换) - mj_low_variation (弱变换)
- mj_pan (平移) - mj_pan (平移)
@ -32,13 +33,14 @@
"mj_variation": 0.1, "mj_variation": 0.1,
"mj_reroll": 0.1, "mj_reroll": 0.1,
"mj_blend": 0.1, "mj_blend": 0.1,
"mj_inpaint": 0.1, "mj_modal": 0.1,
"mj_zoom": 0.1, "mj_zoom": 0.1,
"mj_shorten": 0.1, "mj_shorten": 0.1,
"mj_high_variation": 0.1, "mj_high_variation": 0.1,
"mj_low_variation": 0.1, "mj_low_variation": 0.1,
"mj_pan": 0.1, "mj_pan": 0.1,
"mj_inpaint_pre": 0, "mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05, "mj_describe": 0.05,
"mj_upscale": 0.05, "mj_upscale": 0.05,
"swap_face": 0.05 "swap_face": 0.05

View File

@ -20,7 +20,7 @@ const (
MjActionHighVariation = "HIGH_VARIATION" MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION" MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN" MjActionPan = "PAN"
SwapFace = "SWAP_FACE" MjActionSwapFace = "SWAP_FACE"
) )
var MidjourneyModel2Action = map[string]string{ var MidjourneyModel2Action = map[string]string{
@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_high_variation": MjActionHighVariation, "mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation, "mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan, "mj_pan": MjActionPan,
"swap_face": SwapFace, "swap_face": MjActionSwapFace,
} }

View File

@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyTask(c, relayMode) err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed: case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c) err = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
default: default:
err = relay.RelayMidjourneySubmit(c, relayMode) err = relay.RelayMidjourneySubmit(c, relayMode)
} }

View File

@ -7,6 +7,11 @@ package dto
// Content string `json:"content"` // Content string `json:"content"`
//} //}
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct { type MidjourneyRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
CustomId string `json:"customId"` CustomId string `json:"customId"`

View File

@ -25,6 +25,7 @@ const (
RelayModeMidjourneyAction RelayModeMidjourneyAction
RelayModeMidjourneyModal RelayModeMidjourneyModal
RelayModeMidjourneyShorten RelayModeMidjourneyShorten
RelayModeSwapFace
) )
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasPrefix(path, "/mj/submit/shorten") { } else if strings.HasPrefix(path, "/mj/submit/shorten") {
// midjourney plus // midjourney plus
relayMode = RelayModeMidjourneyShorten relayMode = RelayModeMidjourneyShorten
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
} else if strings.HasPrefix(path, "/mj/submit/imagine") { } else if strings.HasPrefix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(path, "/mj/submit/blend") { } else if strings.HasPrefix(path, "/mj/submit/blend") {

View File

@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return return
} }
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
requestURL := c.Request.URL.String()
baseURL := c.GetString("base_url")
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: constant.MjActionSwapFace,
MjId: midjResponse.Result,
Prompt: "swap_face",
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: startTime,
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
err = midjourneyTask.Insert()
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
}
c.Writer.WriteHeader(mjResp.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
taskId := c.Param("id") taskId := c.Param("id")
userId := c.GetInt("id") userId := c.GetInt("id")
@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil) midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil { if err != nil {
return &midjResponseWithStatus.Response return &midjResponseWithStatus.Response
} }
//defer func(ctx context.Context) {
// err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
// if err != nil {
// common.SysError("error consuming token remain quota: " + err.Error())
// }
// err = model.CacheUpdateUserQuota(userId)
// if err != nil {
// common.SysError("error update user quota cache: " + err.Error())
// }
// if quota != 0 {
// tokenName := c.GetString("token_name")
// logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
// model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
// model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
// channelId := c.GetInt("channel_id")
// model.UpdateChannelUsedQuota(channelId, quota)
// }
//}(c.Request.Context())
midjResponse := &midjResponseWithStatus.Response midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse) respBody, err := json.Marshal(midjResponse)
@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} }
} }
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest) midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil { if err != nil {
return &midjResponseWithStatus.Response return &midjResponseWithStatus.Response
} }
midjResponse := &midjResponseWithStatus.Response midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())

View File

@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
} }
//relayMjRouter.Use() //relayMjRouter.Use()
} }

View File

@ -17,6 +17,9 @@ import (
func CoverActionToModelName(mjAction string) string { func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction) modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == constant.MjActionSwapFace {
modelName = "swap_face"
}
return modelName return modelName
} }
@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = midjRequest.Action action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal: case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionModal action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneySimpleChange: case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content) params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil { if params == nil {
@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
return changeParams return changeParams
} }
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte var nullBytes []byte
var requestBody io.Reader //var requestBody io.Reader
requestBody = c.Request.Body //requestBody = c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) // read request body to json, delete accountFilter and notifyHook
var mapResult map[string]interface{}
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
delete(mapResult, "notifyHook")
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
if err != nil { if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
} }

View File

@ -53,6 +53,8 @@ function renderType(type) {
return <Tag color="teal" size='large'>自定义变焦-提交</Tag>; return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
case 'MODAL': case 'MODAL':
return <Tag color="green" size='large'>窗口处理</Tag>; return <Tag color="green" size='large'>窗口处理</Tag>;
case 'SWAP_FACE':
return <Tag color="light-green" size='large'>换脸</Tag>;
default: default:
return <Tag color="white" size='large'>未知</Tag>; return <Tag color="white" size='large'>未知</Tag>;
} }
@ -67,6 +69,8 @@ function renderCode(code) {
return <Tag color="lime" size='large'>等待中</Tag>; return <Tag color="lime" size='large'>等待中</Tag>;
case 22: case 22:
return <Tag color="orange" size='large'>重复提交</Tag>; return <Tag color="orange" size='large'>重复提交</Tag>;
case 0:
return <Tag color="yellow" size='large'>未提交</Tag>;
default: default:
return <Tag color="white" size='large'>未知</Tag>; return <Tag color="white" size='large'>未知</Tag>;
} }