From 2ad591411eff3f7f1ef91cc012e25ee915dea550 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 17:46:34 +0800
Subject: [PATCH] feat: support shorten
---
Midjourney.md | 287 ++----------------------------
constant/midjourney.go | 1 +
controller/midjourney.go | 4 +
controller/relay.go | 3 +
dto/midjourney.go | 40 +++--
model/midjourney.go | 1 +
relay/constant/relay_mode.go | 1 +
relay/relay-mj.go | 58 +++---
router/relay-router.go | 1 +
web/src/components/MjLogsTable.js | 2 +
10 files changed, 74 insertions(+), 324 deletions(-)
diff --git a/Midjourney.md b/Midjourney.md
index fe4d433..becc9c9 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -7,285 +7,28 @@
```json
{
"gpt-4-gizmo-*": 0.1,
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05
}
```
## 渠道设置
-### 对接 midjourney-proxy
+### 对接 midjourney-proxy(plus)
1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
-1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
-2. 地址填写上游new api的地址,例如:http://localhost:8080
-3. 密钥填写上游new api的密钥
-
-## 任务提交
-
-### 绘图变化
-
-**接口地址**:`/mj/submit/change`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "action"
-:
- "UPSCALE",
- "index"
-:
- 1,
- "notifyHook"
-:
- "",
- "state"
-:
- "",
- "taskId"
-:
- "1320098173412546"
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
-| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
-| action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
-| index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| state | 自定义参数 | | false | string | |
-| taskId | 任务ID | | true | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-### 提交Imagine任务
-
-**接口地址**:`/mj/submit/imagine`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "base64"
-:
- "",
- "notifyHook"
-:
- "",
- "prompt"
-:
- "Cat",
- "state"
-:
- ""
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-------------------------|------|-------|-------------|-------------|
-| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
-| base64 | 垫图base64 | | false | string | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| prompt | 提示词 | | true | string | |
-| state | 自定义参数 | | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-## 任务查询
-
-### 指定ID获取任务
-
-**接口地址**:`/mj/task/{id}/fetch`
-
-**请求方式**:`GET`
-
-**请求数据类型**:`application/x-www-form-urlencoded`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------|------|------|-------|--------|--------|
-| id | 任务ID | path | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 任务 |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|----------------------------------------------------------|----------------|----------------|
-| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
-| description | 任务描述 | string | |
-| failReason | 失败原因 | string | |
-| finishTime | 结束时间 | integer(int64) | integer(int64) |
-| id | 任务ID | string | |
-| imageUrl | 图片url | string | |
-| progress | 任务进度 | string | |
-| prompt | 提示词 | string | |
-| promptEn | 提示词-英文 | string | |
-| startTime | 开始执行时间 | integer(int64) | integer(int64) |
-| state | 自定义参数 | string | |
-| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
-| submitTime | 提交时间 | integer(int64) | integer(int64) |
-
-**响应示例**:
-
-```javascript
-{
- "action"
-:
- "",
- "description"
-:
- "",
- "failReason"
-:
- "",
- "finishTime"
-:
- 0,
- "id"
-:
- "",
- "imageUrl"
-:
- "",
- "progress"
-:
- "",
- "prompt"
-:
- "",
- "promptEn"
-:
- "",
- "startTime"
-:
- 0,
- "state"
-:
- "",
- "status"
-:
- "",
- "submitTime"
-:
- 0
-}
-```
\ No newline at end of file
+1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
+2. 地址填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/constant/midjourney.go b/constant/midjourney.go
index c184435..a5bccb7 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -14,4 +14,5 @@ const (
MjActionInPaint = "INPAINT"
MjActionInPaintPre = "INPAINT_PRE"
MjActionZoom = "ZOOM"
+ MjActionShorten = "SHORTEN"
)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index cac253c..b666e91 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -263,6 +263,10 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
+ if responseItem.Properties != nil {
+ propertiesStr, _ := json.Marshal(responseItem.Properties)
+ task.Properties = string(propertiesStr)
+ }
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
diff --git a/controller/relay.go b/controller/relay.go
index a42db2e..7652840 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -68,6 +68,9 @@ func RelayMidjourney(c *gin.Context) {
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
// midjourney plus
relayMode = relayconstant.RelayModeMidjourneyModal
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyShorten
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 4fef4e1..d3b19d5 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -22,23 +22,24 @@ type MidjourneyResponse struct {
}
type MidjourneyDto struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- CustomId string `json:"customId"`
- BotType string `json:"botType"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
- Buttons any `json:"buttons"`
- MaskBase64 string `json:"maskBase64"`
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+ Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
@@ -70,3 +71,8 @@ type ActionButton struct {
Type any `json:"type"`
Style any `json:"style"`
}
+
+type Properties struct {
+ FinalPrompt string `json:"finalPrompt"`
+ FinalZhPrompt string `json:"finalZhPrompt"`
+}
diff --git a/model/midjourney.go b/model/midjourney.go
index f20ab32..dd065a3 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -20,6 +20,7 @@ type Midjourney struct {
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
+ Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index c49caae..d8dc7ee 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -23,6 +23,7 @@ const (
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
+ RelayModeMidjourneyShorten
)
func Path2RelayMode(path string) int {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index d582055..a1f6ed4 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -31,6 +31,7 @@ var DefaultModelPrice = map[string]float64{
"mj_inpaint_pre": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
+ "swap_face": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
@@ -140,6 +141,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Buttons = buttons
}
}
+ if originTask.Properties != "" {
+ var properties dto.Properties
+ err := json.Unmarshal([]byte(originTask.Properties), &properties)
+ if err == nil {
+ midjourneyTask.Properties = &properties
+ }
+ }
return
}
@@ -260,9 +268,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if midjRequest.Prompt == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
- midjRequest.Action = "IMAGINE"
+ midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
- midjRequest.Action = "DESCRIBE"
+ midjRequest.Action = constant.MjActionDescribe
+ } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
+ midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
@@ -292,7 +302,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
}
mjId = midjRequest.TaskId
- midjRequest.Action = "INPAINT"
+ midjRequest.Action = constant.MjActionInPaint
}
originTask := model.GetByMJId(userId, mjId)
@@ -418,25 +428,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer cancel()
resp, err := service.GetHttpClient().Do(req)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "do_request_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
}
err = req.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "close_request_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
}
err = c.Request.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "close_request_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
}
var midjResponse dto.MidjourneyResponse
@@ -464,33 +465,20 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "read_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
}
err = resp.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
+ }
+ if resp.StatusCode != 200 {
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
}
-
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
- if resp.StatusCode != 200 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
- }
- }
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
}
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
@@ -651,7 +639,7 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionVariation
midjRequest.Index = 1
- } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") {
+ } else if action == "Outpaint" || action == "CustomZoom" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
diff --git a/router/relay-router.go b/router/relay-router.go
index 68b762b..f572d8f 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -48,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index a1ffeb6..fe6554e 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,8 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'SHORTEN':
+ return 缩词;
case 'REROLL':
return 重绘;
case 'INPAINT':