mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-20 09:16:37 +08:00
feat: support InsightFace (close #60)
This commit is contained in:
parent
bc5a54df59
commit
9b5353a81a
@ -19,8 +19,9 @@
|
|||||||
|
|
||||||
- mj_zoom (比例变焦)
|
- mj_zoom (比例变焦)
|
||||||
- mj_shorten (提示词缩短)
|
- mj_shorten (提示词缩短)
|
||||||
- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加)
|
- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
|
||||||
- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加)
|
- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
|
||||||
|
- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
|
||||||
- mj_high_variation (强变换)
|
- mj_high_variation (强变换)
|
||||||
- mj_low_variation (弱变换)
|
- mj_low_variation (弱变换)
|
||||||
- mj_pan (平移)
|
- mj_pan (平移)
|
||||||
@ -32,13 +33,14 @@
|
|||||||
"mj_variation": 0.1,
|
"mj_variation": 0.1,
|
||||||
"mj_reroll": 0.1,
|
"mj_reroll": 0.1,
|
||||||
"mj_blend": 0.1,
|
"mj_blend": 0.1,
|
||||||
"mj_inpaint": 0.1,
|
"mj_modal": 0.1,
|
||||||
"mj_zoom": 0.1,
|
"mj_zoom": 0.1,
|
||||||
"mj_shorten": 0.1,
|
"mj_shorten": 0.1,
|
||||||
"mj_high_variation": 0.1,
|
"mj_high_variation": 0.1,
|
||||||
"mj_low_variation": 0.1,
|
"mj_low_variation": 0.1,
|
||||||
"mj_pan": 0.1,
|
"mj_pan": 0.1,
|
||||||
"mj_inpaint_pre": 0,
|
"mj_inpaint": 0,
|
||||||
|
"mj_custom_zoom": 0,
|
||||||
"mj_describe": 0.05,
|
"mj_describe": 0.05,
|
||||||
"mj_upscale": 0.05,
|
"mj_upscale": 0.05,
|
||||||
"swap_face": 0.05
|
"swap_face": 0.05
|
||||||
|
@ -20,7 +20,7 @@ const (
|
|||||||
MjActionHighVariation = "HIGH_VARIATION"
|
MjActionHighVariation = "HIGH_VARIATION"
|
||||||
MjActionLowVariation = "LOW_VARIATION"
|
MjActionLowVariation = "LOW_VARIATION"
|
||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
SwapFace = "SWAP_FACE"
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
)
|
)
|
||||||
|
|
||||||
var MidjourneyModel2Action = map[string]string{
|
var MidjourneyModel2Action = map[string]string{
|
||||||
@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{
|
|||||||
"mj_high_variation": MjActionHighVariation,
|
"mj_high_variation": MjActionHighVariation,
|
||||||
"mj_low_variation": MjActionLowVariation,
|
"mj_low_variation": MjActionLowVariation,
|
||||||
"mj_pan": MjActionPan,
|
"mj_pan": MjActionPan,
|
||||||
"swap_face": SwapFace,
|
"swap_face": MjActionSwapFace,
|
||||||
}
|
}
|
||||||
|
@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) {
|
|||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
|
case relayconstant.RelayModeSwapFace:
|
||||||
|
err = relay.RelaySwapFace(c)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,11 @@ package dto
|
|||||||
// Content string `json:"content"`
|
// Content string `json:"content"`
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
type SwapFaceRequest struct {
|
||||||
|
SourceBase64 string `json:"sourceBase64"`
|
||||||
|
TargetBase64 string `json:"targetBase64"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyRequest struct {
|
type MidjourneyRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
CustomId string `json:"customId"`
|
CustomId string `json:"customId"`
|
||||||
|
@ -25,6 +25,7 @@ const (
|
|||||||
RelayModeMidjourneyAction
|
RelayModeMidjourneyAction
|
||||||
RelayModeMidjourneyModal
|
RelayModeMidjourneyModal
|
||||||
RelayModeMidjourneyShorten
|
RelayModeMidjourneyShorten
|
||||||
|
RelayModeSwapFace
|
||||||
)
|
)
|
||||||
|
|
||||||
func Path2RelayMode(path string) int {
|
func Path2RelayMode(path string) int {
|
||||||
@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int {
|
|||||||
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
|
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
|
||||||
// midjourney plus
|
// midjourney plus
|
||||||
relayMode = RelayModeMidjourneyShorten
|
relayMode = RelayModeMidjourneyShorten
|
||||||
|
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeSwapFace
|
||||||
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
|
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
|
||||||
relayMode = RelayModeMidjourneyImagine
|
relayMode = RelayModeMidjourneyImagine
|
||||||
} else if strings.HasPrefix(path, "/mj/submit/blend") {
|
} else if strings.HasPrefix(path, "/mj/submit/blend") {
|
||||||
|
@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
|
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
group := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
var swapFaceRequest dto.SwapFaceRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||||
|
}
|
||||||
|
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||||
|
}
|
||||||
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||||
|
modelPrice := common.GetModelPrice(modelName, true)
|
||||||
|
// 如果没有配置价格,则使用默认价格
|
||||||
|
if modelPrice == -1 {
|
||||||
|
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||||
|
if !ok {
|
||||||
|
modelPrice = 0.1
|
||||||
|
} else {
|
||||||
|
modelPrice = defaultPrice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groupRatio := common.GetGroupRatio(group)
|
||||||
|
ratio := modelPrice * groupRatio
|
||||||
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
return &dto.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quota := int(ratio * common.QuotaPerUnit)
|
||||||
|
|
||||||
|
if userQuota-quota < 0 {
|
||||||
|
return &dto.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "quota_not_enough",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requestURL := c.Request.URL.String()
|
||||||
|
baseURL := c.GetString("base_url")
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
|
||||||
|
if err != nil {
|
||||||
|
return &mjResp.Response
|
||||||
|
}
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||||
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
midjResponse := &mjResp.Response
|
||||||
|
midjourneyTask := &model.Midjourney{
|
||||||
|
UserId: userId,
|
||||||
|
Code: midjResponse.Code,
|
||||||
|
Action: constant.MjActionSwapFace,
|
||||||
|
MjId: midjResponse.Result,
|
||||||
|
Prompt: "swap_face",
|
||||||
|
PromptEn: "",
|
||||||
|
Description: midjResponse.Description,
|
||||||
|
State: "",
|
||||||
|
SubmitTime: startTime,
|
||||||
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||||
|
FinishTime: 0,
|
||||||
|
ImageUrl: "",
|
||||||
|
Status: "",
|
||||||
|
Progress: "0%",
|
||||||
|
FailReason: "",
|
||||||
|
ChannelId: c.GetInt("channel_id"),
|
||||||
|
Quota: quota,
|
||||||
|
}
|
||||||
|
err = midjourneyTask.Insert()
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
taskId := c.Param("id")
|
taskId := c.Param("id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
|
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
|
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &midjResponseWithStatus.Response
|
return &midjResponseWithStatus.Response
|
||||||
}
|
}
|
||||||
|
//defer func(ctx context.Context) {
|
||||||
|
// err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
|
// if err != nil {
|
||||||
|
// common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
// }
|
||||||
|
// err = model.CacheUpdateUserQuota(userId)
|
||||||
|
// if err != nil {
|
||||||
|
// common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
// }
|
||||||
|
// if quota != 0 {
|
||||||
|
// tokenName := c.GetString("token_name")
|
||||||
|
// logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||||
|
// model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
|
// model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
// channelId := c.GetInt("channel_id")
|
||||||
|
// model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
// }
|
||||||
|
//}(c.Request.Context())
|
||||||
midjResponse := &midjResponseWithStatus.Response
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
respBody, err := json.Marshal(midjResponse)
|
respBody, err := json.Marshal(midjResponse)
|
||||||
@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
|
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &midjResponseWithStatus.Response
|
return &midjResponseWithStatus.Response
|
||||||
}
|
}
|
||||||
midjResponse := &midjResponseWithStatus.Response
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||||
}
|
}
|
||||||
//relayMjRouter.Use()
|
//relayMjRouter.Use()
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,9 @@ import (
|
|||||||
|
|
||||||
func CoverActionToModelName(mjAction string) string {
|
func CoverActionToModelName(mjAction string) string {
|
||||||
modelName := "mj_" + strings.ToLower(mjAction)
|
modelName := "mj_" + strings.ToLower(mjAction)
|
||||||
|
if mjAction == constant.MjActionSwapFace {
|
||||||
|
modelName = "swap_face"
|
||||||
|
}
|
||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
|
|||||||
action = midjRequest.Action
|
action = midjRequest.Action
|
||||||
case relayconstant.RelayModeMidjourneyModal:
|
case relayconstant.RelayModeMidjourneyModal:
|
||||||
action = constant.MjActionModal
|
action = constant.MjActionModal
|
||||||
|
case relayconstant.RelayModeSwapFace:
|
||||||
|
action = constant.MjActionSwapFace
|
||||||
case relayconstant.RelayModeMidjourneySimpleChange:
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
|||||||
return changeParams
|
return changeParams
|
||||||
}
|
}
|
||||||
|
|
||||||
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
||||||
var nullBytes []byte
|
var nullBytes []byte
|
||||||
var requestBody io.Reader
|
//var requestBody io.Reader
|
||||||
requestBody = c.Request.Body
|
//requestBody = c.Request.Body
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
// read request body to json, delete accountFilter and notifyHook
|
||||||
|
var mapResult map[string]interface{}
|
||||||
|
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
delete(mapResult, "accountFilter")
|
||||||
|
delete(mapResult, "notifyHook")
|
||||||
|
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
// make new request with mapResult
|
||||||
|
reqBody, err := json.Marshal(mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
}
|
}
|
||||||
|
@ -53,6 +53,8 @@ function renderType(type) {
|
|||||||
return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
|
return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
|
||||||
case 'MODAL':
|
case 'MODAL':
|
||||||
return <Tag color="green" size='large'>窗口处理</Tag>;
|
return <Tag color="green" size='large'>窗口处理</Tag>;
|
||||||
|
case 'SWAP_FACE':
|
||||||
|
return <Tag color="light-green" size='large'>换脸</Tag>;
|
||||||
default:
|
default:
|
||||||
return <Tag color="white" size='large'>未知</Tag>;
|
return <Tag color="white" size='large'>未知</Tag>;
|
||||||
}
|
}
|
||||||
@ -67,6 +69,8 @@ function renderCode(code) {
|
|||||||
return <Tag color="lime" size='large'>等待中</Tag>;
|
return <Tag color="lime" size='large'>等待中</Tag>;
|
||||||
case 22:
|
case 22:
|
||||||
return <Tag color="orange" size='large'>重复提交</Tag>;
|
return <Tag color="orange" size='large'>重复提交</Tag>;
|
||||||
|
case 0:
|
||||||
|
return <Tag color="yellow" size='large'>未提交</Tag>;
|
||||||
default:
|
default:
|
||||||
return <Tag color="white" size='large'>未知</Tag>;
|
return <Tag color="white" size='large'>未知</Tag>;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user