From 2ad591411eff3f7f1ef91cc012e25ee915dea550 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 17:46:34 +0800 Subject: [PATCH] feat: support shorten --- Midjourney.md | 287 ++---------------------------- constant/midjourney.go | 1 + controller/midjourney.go | 4 + controller/relay.go | 3 + dto/midjourney.go | 40 +++-- model/midjourney.go | 1 + relay/constant/relay_mode.go | 1 + relay/relay-mj.go | 58 +++--- router/relay-router.go | 1 + web/src/components/MjLogsTable.js | 2 + 10 files changed, 74 insertions(+), 324 deletions(-) diff --git a/Midjourney.md b/Midjourney.md index fe4d433..becc9c9 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -7,285 +7,28 @@ ```json { "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05 + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05 } ``` ## 渠道设置 -### 对接 midjourney-proxy +### 对接 midjourney-proxy(plus) 1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) -2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney +2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face 3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080 4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填 ### 对接上游new api -1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney -2. 地址填写上游new api的地址,例如:http://localhost:8080 -3. 密钥填写上游new api的密钥 - -## 任务提交 - -### 绘图变化 - -**接口地址**:`/mj/submit/change` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "action" -: - "UPSCALE", - "index" -: - 1, - "notifyHook" -: - "", - "state" -: - "", - "taskId" -: - "1320098173412546" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------| -| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 | -|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | | -|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   state | 自定义参数 | | false | string | | -|   taskId | 任务ID | | true | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -### 提交Imagine任务 - -**接口地址**:`/mj/submit/imagine` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "base64" -: - "", - "notifyHook" -: - "", - "prompt" -: - "Cat", - "state" -: - "" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-------------------------|------|-------|-------------|-------------| -| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 | -|   base64 | 垫图base64 | | false | string | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   prompt | 提示词 | | true | string | | -|   state | 自定义参数 | | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -## 任务查询 - -### 指定ID获取任务 - -**接口地址**:`/mj/task/{id}/fetch` - -**请求方式**:`GET` - -**请求数据类型**:`application/x-www-form-urlencoded` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------|------|------|-------|--------|--------| -| id | 任务ID | path | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 任务 | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|----------------------------------------------------------|----------------|----------------| -| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | | -| description | 任务描述 | string | | -| failReason | 失败原因 | string | | -| finishTime | 结束时间 | integer(int64) | integer(int64) | -| id | 任务ID | string | | -| imageUrl | 图片url | string | | -| progress | 任务进度 | string | | -| prompt | 提示词 | string | | -| promptEn | 提示词-英文 | string | | -| startTime | 开始执行时间 | integer(int64) | integer(int64) | -| state | 自定义参数 | string | | -| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | | -| submitTime | 提交时间 | integer(int64) | integer(int64) | - -**响应示例**: - -```javascript -{ - "action" -: - "", - "description" -: - "", - "failReason" -: - "", - "finishTime" -: - 0, - "id" -: - "", - "imageUrl" -: - "", - "progress" -: - "", - "prompt" -: - "", - "promptEn" -: - "", - "startTime" -: - 0, - "state" -: - "", - "status" -: - "", - "submitTime" -: - 0 -} -``` \ No newline at end of file +1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face +2. 地址填写上游new api的地址,例如:http://localhost:3000 +3. 密钥填写上游new api的密钥 \ No newline at end of file diff --git a/constant/midjourney.go b/constant/midjourney.go index c184435..a5bccb7 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -14,4 +14,5 @@ const ( MjActionInPaint = "INPAINT" MjActionInPaintPre = "INPAINT_PRE" MjActionZoom = "ZOOM" + MjActionShorten = "SHORTEN" ) diff --git a/controller/midjourney.go b/controller/midjourney.go index cac253c..b666e91 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -263,6 +263,10 @@ func UpdateMidjourneyTaskBulk() { task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason + if responseItem.Properties != nil { + propertiesStr, _ := json.Marshal(responseItem.Properties) + task.Properties = string(propertiesStr) + } if responseItem.Buttons != nil { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) diff --git a/controller/relay.go b/controller/relay.go index a42db2e..7652840 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -68,6 +68,9 @@ func RelayMidjourney(c *gin.Context) { } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") { // midjourney plus relayMode = relayconstant.RelayModeMidjourneyModal + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyShorten } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { relayMode = relayconstant.RelayModeMidjourneyImagine } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { diff --git a/dto/midjourney.go b/dto/midjourney.go index 4fef4e1..d3b19d5 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -22,23 +22,24 @@ type MidjourneyResponse struct { } type MidjourneyDto struct { - MjId string `json:"id"` - Action string `json:"action"` - CustomId string `json:"customId"` - BotType string `json:"botType"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - ImageUrl string `json:"imageUrl"` - Status string `json:"status"` - Progress string `json:"progress"` - FailReason string `json:"failReason"` - Buttons any `json:"buttons"` - MaskBase64 string `json:"maskBase64"` + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + BotType string `json:"botType"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` + Properties *Properties `json:"properties"` } type MidjourneyStatus struct { @@ -70,3 +71,8 @@ type ActionButton struct { Type any `json:"type"` Style any `json:"style"` } + +type Properties struct { + FinalPrompt string `json:"finalPrompt"` + FinalZhPrompt string `json:"finalZhPrompt"` +} diff --git a/model/midjourney.go b/model/midjourney.go index f20ab32..dd065a3 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -20,6 +20,7 @@ type Midjourney struct { ChannelId int `json:"channel_id"` Quota int `json:"quota"` Buttons string `json:"buttons"` + Properties string `json:"properties"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index c49caae..d8dc7ee 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -23,6 +23,7 @@ const ( RelayModeAudioTranslation RelayModeMidjourneyAction RelayModeMidjourneyModal + RelayModeMidjourneyShorten ) func Path2RelayMode(path string) int { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index d582055..a1f6ed4 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -31,6 +31,7 @@ var DefaultModelPrice = map[string]float64{ "mj_inpaint_pre": 0, "mj_describe": 0.05, "mj_upscale": 0.05, + "swap_face": 0.05, } func RelayMidjourneyImage(c *gin.Context) { @@ -140,6 +141,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Buttons = buttons } } + if originTask.Properties != "" { + var properties dto.Properties + err := json.Unmarshal([]byte(originTask.Properties), &properties) + if err == nil { + midjourneyTask.Properties = &properties + } + } return } @@ -260,9 +268,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } - midjRequest.Action = "IMAGINE" + midjRequest.Action = constant.MjActionImagine } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 - midjRequest.Action = "DESCRIBE" + midjRequest.Action = constant.MjActionDescribe + } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = "BLEND" } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 @@ -292,7 +302,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") } mjId = midjRequest.TaskId - midjRequest.Action = "INPAINT" + midjRequest.Action = constant.MjActionInPaint } originTask := model.GetByMJId(userId, mjId) @@ -418,25 +428,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer cancel() resp, err := service.GetHttpClient().Do(req) if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "do_request_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed") } err = req.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "close_request_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") } err = c.Request.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "close_request_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") } var midjResponse dto.MidjourneyResponse @@ -464,33 +465,20 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody, err := io.ReadAll(resp.Body) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "read_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed") } err = resp.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "close_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed") + } + if resp.StatusCode != 200 { + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status") } - err = json.Unmarshal(responseBody, &midjResponse) log.Printf("responseBody: %s", string(responseBody)) log.Printf("midjResponse: %v", midjResponse) - if resp.StatusCode != 200 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode), - } - } if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed") } // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md @@ -651,7 +639,7 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj } else if strings.Contains(action, "pan") { midjRequest.Action = constant.MjActionVariation midjRequest.Index = 1 - } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") { + } else if action == "Outpaint" || action == "CustomZoom" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 } else if action == "Inpaint" { diff --git a/router/relay-router.go b/router/relay-router.go index 68b762b..f572d8f 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -48,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relayMjRouter.POST("/submit/action", controller.RelayMidjourney) + relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney) relayMjRouter.POST("/submit/modal", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index a1ffeb6..fe6554e 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,8 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'SHORTEN': + return 缩词; case 'REROLL': return 重绘; case 'INPAINT':