mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-18 00:16:37 +08:00
Merge pull request #114 from Calcium-Ion/midjourney-proxy-plus
feat: support midjourney-proxy-plus
This commit is contained in:
commit
f62dcbf669
312
Midjourney.md
312
Midjourney.md
@ -4,288 +4,64 @@
|
|||||||
|
|
||||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||||
|
|
||||||
|
### 模型列表
|
||||||
|
|
||||||
|
### midjourney-proxy支持
|
||||||
|
|
||||||
|
- mj_imagine (绘图)
|
||||||
|
- mj_variation (变换)
|
||||||
|
- mj_reroll (重绘)
|
||||||
|
- mj_blend (混合)
|
||||||
|
- mj_upscale (放大)
|
||||||
|
- mj_describe (图生文)
|
||||||
|
|
||||||
|
### 仅midjourney-proxy-plus支持
|
||||||
|
|
||||||
|
- mj_zoom (比例变焦)
|
||||||
|
- mj_shorten (提示词缩短)
|
||||||
|
- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
|
||||||
|
- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
|
||||||
|
- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
|
||||||
|
- mj_high_variation (强变换)
|
||||||
|
- mj_low_variation (弱变换)
|
||||||
|
- mj_pan (平移)
|
||||||
|
- swap_face (换脸)
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"gpt-4-gizmo-*": 0.1,
|
|
||||||
"mj_imagine": 0.1,
|
"mj_imagine": 0.1,
|
||||||
"mj_variation": 0.1,
|
"mj_variation": 0.1,
|
||||||
"mj_reroll": 0.1,
|
"mj_reroll": 0.1,
|
||||||
"mj_blend": 0.1,
|
"mj_blend": 0.1,
|
||||||
|
"mj_modal": 0.1,
|
||||||
|
"mj_zoom": 0.1,
|
||||||
|
"mj_shorten": 0.1,
|
||||||
|
"mj_high_variation": 0.1,
|
||||||
|
"mj_low_variation": 0.1,
|
||||||
|
"mj_pan": 0.1,
|
||||||
|
"mj_inpaint": 0,
|
||||||
|
"mj_custom_zoom": 0,
|
||||||
"mj_describe": 0.05,
|
"mj_describe": 0.05,
|
||||||
"mj_upscale": 0.05
|
"mj_upscale": 0.05,
|
||||||
|
"swap_face": 0.05
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 渠道设置
|
## 渠道设置
|
||||||
|
|
||||||
### 对接 midjourney-proxy
|
### 对接 midjourney-proxy(plus)
|
||||||
1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
|
||||||
2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
1.
|
||||||
|
|
||||||
|
部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||||
|
|
||||||
|
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
||||||
|
,模型选择midjourney,如果有换脸模型,可以选择swap_face
|
||||||
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
||||||
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
||||||
|
|
||||||
### 对接上游new api
|
### 对接上游new api
|
||||||
1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
|
||||||
2. 地址填写上游new api的地址,例如:http://localhost:8080
|
|
||||||
3. 密钥填写上游new api的密钥
|
|
||||||
|
|
||||||
## 任务提交
|
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
|
||||||
|
2. 地址填写上游new api的地址,例如:http://localhost:3000
|
||||||
### 绘图变化
|
3. 密钥填写上游new api的密钥
|
||||||
|
|
||||||
**接口地址**:`/mj/submit/change`
|
|
||||||
|
|
||||||
**请求方式**:`POST`
|
|
||||||
|
|
||||||
**请求数据类型**:`application/json`
|
|
||||||
|
|
||||||
**响应数据类型**:`*/*`
|
|
||||||
|
|
||||||
**接口描述**:
|
|
||||||
|
|
||||||
**请求示例**:
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
{
|
|
||||||
"action"
|
|
||||||
:
|
|
||||||
"UPSCALE",
|
|
||||||
"index"
|
|
||||||
:
|
|
||||||
1,
|
|
||||||
"notifyHook"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"state"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"taskId"
|
|
||||||
:
|
|
||||||
"1320098173412546"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**请求参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
|
||||||
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
|
|
||||||
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
|
|
||||||
|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
|
|
||||||
|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
|
|
||||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
|
||||||
|   state | 自定义参数 | | false | string | |
|
|
||||||
|   taskId | 任务ID | | true | string | |
|
|
||||||
|
|
||||||
**响应状态**:
|
|
||||||
|
|
||||||
| 状态码 | 说明 | schema |
|
|
||||||
|-----|--------------|--------|
|
|
||||||
| 200 | OK | 提交结果 |
|
|
||||||
| 201 | Created | |
|
|
||||||
| 401 | Unauthorized | |
|
|
||||||
| 403 | Forbidden | |
|
|
||||||
| 404 | Not Found | |
|
|
||||||
|
|
||||||
**响应参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 类型 | schema |
|
|
||||||
|-------------|-------------------------------------------|----------------|----------------|
|
|
||||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
|
||||||
| description | 描述 | string | |
|
|
||||||
| properties | 扩展字段 | object | |
|
|
||||||
| result | 任务ID | string | |
|
|
||||||
|
|
||||||
**响应示例**:
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
{
|
|
||||||
"code"
|
|
||||||
:
|
|
||||||
1,
|
|
||||||
"description"
|
|
||||||
:
|
|
||||||
"提交成功",
|
|
||||||
"properties"
|
|
||||||
:
|
|
||||||
{
|
|
||||||
}
|
|
||||||
,
|
|
||||||
"result"
|
|
||||||
:
|
|
||||||
1320098173412546
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 提交Imagine任务
|
|
||||||
|
|
||||||
**接口地址**:`/mj/submit/imagine`
|
|
||||||
|
|
||||||
**请求方式**:`POST`
|
|
||||||
|
|
||||||
**请求数据类型**:`application/json`
|
|
||||||
|
|
||||||
**响应数据类型**:`*/*`
|
|
||||||
|
|
||||||
**接口描述**:
|
|
||||||
|
|
||||||
**请求示例**:
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
{
|
|
||||||
"base64"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"notifyHook"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"prompt"
|
|
||||||
:
|
|
||||||
"Cat",
|
|
||||||
"state"
|
|
||||||
:
|
|
||||||
""
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**请求参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
|
||||||
|------------------------|-------------------------|------|-------|-------------|-------------|
|
|
||||||
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
|
|
||||||
|   base64 | 垫图base64 | | false | string | |
|
|
||||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
|
||||||
|   prompt | 提示词 | | true | string | |
|
|
||||||
|   state | 自定义参数 | | false | string | |
|
|
||||||
|
|
||||||
**响应状态**:
|
|
||||||
|
|
||||||
| 状态码 | 说明 | schema |
|
|
||||||
|-----|--------------|--------|
|
|
||||||
| 200 | OK | 提交结果 |
|
|
||||||
| 201 | Created | |
|
|
||||||
| 401 | Unauthorized | |
|
|
||||||
| 403 | Forbidden | |
|
|
||||||
| 404 | Not Found | |
|
|
||||||
|
|
||||||
**响应参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 类型 | schema |
|
|
||||||
|-------------|-------------------------------------------|----------------|----------------|
|
|
||||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
|
||||||
| description | 描述 | string | |
|
|
||||||
| properties | 扩展字段 | object | |
|
|
||||||
| result | 任务ID | string | |
|
|
||||||
|
|
||||||
**响应示例**:
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
{
|
|
||||||
"code"
|
|
||||||
:
|
|
||||||
1,
|
|
||||||
"description"
|
|
||||||
:
|
|
||||||
"提交成功",
|
|
||||||
"properties"
|
|
||||||
:
|
|
||||||
{
|
|
||||||
}
|
|
||||||
,
|
|
||||||
"result"
|
|
||||||
:
|
|
||||||
1320098173412546
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 任务查询
|
|
||||||
|
|
||||||
### 指定ID获取任务
|
|
||||||
|
|
||||||
**接口地址**:`/mj/task/{id}/fetch`
|
|
||||||
|
|
||||||
**请求方式**:`GET`
|
|
||||||
|
|
||||||
**请求数据类型**:`application/x-www-form-urlencoded`
|
|
||||||
|
|
||||||
**响应数据类型**:`*/*`
|
|
||||||
|
|
||||||
**接口描述**:
|
|
||||||
|
|
||||||
**请求参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
|
||||||
|------|------|------|-------|--------|--------|
|
|
||||||
| id | 任务ID | path | false | string | |
|
|
||||||
|
|
||||||
**响应状态**:
|
|
||||||
|
|
||||||
| 状态码 | 说明 | schema |
|
|
||||||
|-----|--------------|--------|
|
|
||||||
| 200 | OK | 任务 |
|
|
||||||
| 401 | Unauthorized | |
|
|
||||||
| 403 | Forbidden | |
|
|
||||||
| 404 | Not Found | |
|
|
||||||
|
|
||||||
**响应参数**:
|
|
||||||
|
|
||||||
| 参数名称 | 参数说明 | 类型 | schema |
|
|
||||||
|-------------|----------------------------------------------------------|----------------|----------------|
|
|
||||||
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
|
|
||||||
| description | 任务描述 | string | |
|
|
||||||
| failReason | 失败原因 | string | |
|
|
||||||
| finishTime | 结束时间 | integer(int64) | integer(int64) |
|
|
||||||
| id | 任务ID | string | |
|
|
||||||
| imageUrl | 图片url | string | |
|
|
||||||
| progress | 任务进度 | string | |
|
|
||||||
| prompt | 提示词 | string | |
|
|
||||||
| promptEn | 提示词-英文 | string | |
|
|
||||||
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
|
|
||||||
| state | 自定义参数 | string | |
|
|
||||||
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
|
|
||||||
| submitTime | 提交时间 | integer(int64) | integer(int64) |
|
|
||||||
|
|
||||||
**响应示例**:
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
{
|
|
||||||
"action"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"description"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"failReason"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"finishTime"
|
|
||||||
:
|
|
||||||
0,
|
|
||||||
"id"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"imageUrl"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"progress"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"prompt"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"promptEn"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"startTime"
|
|
||||||
:
|
|
||||||
0,
|
|
||||||
"state"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"status"
|
|
||||||
:
|
|
||||||
"",
|
|
||||||
"submitTime"
|
|
||||||
:
|
|
||||||
0
|
|
||||||
}
|
|
||||||
```
|
|
@ -18,7 +18,7 @@
|
|||||||
此分叉版本的主要变更如下:
|
此分叉版本的主要变更如下:
|
||||||
|
|
||||||
1. 全新的UI界面(部分界面还待更新)
|
1. 全新的UI界面(部分界面还待更新)
|
||||||
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
|
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持
|
||||||
+ [x] /mj/submit/imagine
|
+ [x] /mj/submit/imagine
|
||||||
+ [x] /mj/submit/change
|
+ [x] /mj/submit/change
|
||||||
+ [x] /mj/submit/blend
|
+ [x] /mj/submit/blend
|
||||||
@ -26,6 +26,11 @@
|
|||||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||||
+ [x] /task/list-by-condition
|
+ [x] /task/list-by-condition
|
||||||
|
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||||
|
+ [x] /mj/submit/modal
|
||||||
|
+ [x] /mj/submit/shorten
|
||||||
|
+ [x] /mj/task/{id}/image-seed
|
||||||
|
+ [x] /mj/insight-face/swap (InsightFace)
|
||||||
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
||||||
+ [x] 易支付
|
+ [x] 易支付
|
||||||
4. 支持用key查询使用额度:
|
4. 支持用key查询使用额度:
|
||||||
@ -49,6 +54,7 @@
|
|||||||
2. 智谱glm-4v,glm-4v识图
|
2. 智谱glm-4v,glm-4v识图
|
||||||
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
||||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||||
|
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口
|
||||||
|
|
||||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||||
|
|
||||||
|
@ -189,7 +189,7 @@ const (
|
|||||||
ChannelTypeMidjourney = 2
|
ChannelTypeMidjourney = 2
|
||||||
ChannelTypeAzure = 3
|
ChannelTypeAzure = 3
|
||||||
ChannelTypeOllama = 4
|
ChannelTypeOllama = 4
|
||||||
ChannelTypeOpenAISB = 5
|
ChannelTypeMidjourneyPlus = 5
|
||||||
ChannelTypeOpenAIMax = 6
|
ChannelTypeOpenAIMax = 6
|
||||||
ChannelTypeOhMyGPT = 7
|
ChannelTypeOhMyGPT = 7
|
||||||
ChannelTypeCustom = 8
|
ChannelTypeCustom = 8
|
||||||
|
@ -95,17 +95,31 @@ var ModelRatio = map[string]float64{
|
|||||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModelPrice = map[string]float64{
|
var DefaultModelPrice = map[string]float64{
|
||||||
"gpt-4-gizmo-*": 0.1,
|
"gpt-4-gizmo-*": 0.1,
|
||||||
"mj_imagine": 0.1,
|
"mj_imagine": 0.1,
|
||||||
"mj_variation": 0.1,
|
"mj_variation": 0.1,
|
||||||
"mj_reroll": 0.1,
|
"mj_reroll": 0.1,
|
||||||
"mj_blend": 0.1,
|
"mj_blend": 0.1,
|
||||||
"mj_describe": 0.05,
|
"mj_modal": 0.1,
|
||||||
"mj_upscale": 0.05,
|
"mj_zoom": 0.1,
|
||||||
|
"mj_shorten": 0.1,
|
||||||
|
"mj_high_variation": 0.1,
|
||||||
|
"mj_low_variation": 0.1,
|
||||||
|
"mj_pan": 0.1,
|
||||||
|
"mj_inpaint": 0,
|
||||||
|
"mj_custom_zoom": 0,
|
||||||
|
"mj_describe": 0.05,
|
||||||
|
"mj_upscale": 0.05,
|
||||||
|
"swap_face": 0.05,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ModelPrice = map[string]float64{}
|
||||||
|
|
||||||
func ModelPrice2JSONString() string {
|
func ModelPrice2JSONString() string {
|
||||||
|
if len(ModelPrice) == 0 {
|
||||||
|
ModelPrice = DefaultModelPrice
|
||||||
|
}
|
||||||
jsonBytes, err := json.Marshal(ModelPrice)
|
jsonBytes, err := json.Marshal(ModelPrice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model price: " + err.Error())
|
SysError("error marshalling model price: " + err.Error())
|
||||||
@ -119,6 +133,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetModelPrice(name string, printErr bool) float64 {
|
func GetModelPrice(name string, printErr bool) float64 {
|
||||||
|
if len(ModelPrice) == 0 {
|
||||||
|
ModelPrice = DefaultModelPrice
|
||||||
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
|
42
constant/midjourney.go
Normal file
42
constant/midjourney.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
MjErrorUnknown = 5
|
||||||
|
MjRequestError = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MjActionImagine = "IMAGINE"
|
||||||
|
MjActionDescribe = "DESCRIBE"
|
||||||
|
MjActionBlend = "BLEND"
|
||||||
|
MjActionUpscale = "UPSCALE"
|
||||||
|
MjActionVariation = "VARIATION"
|
||||||
|
MjActionReRoll = "REROLL"
|
||||||
|
MjActionInPaint = "INPAINT"
|
||||||
|
MjActionModal = "MODAL"
|
||||||
|
MjActionZoom = "ZOOM"
|
||||||
|
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||||
|
MjActionShorten = "SHORTEN"
|
||||||
|
MjActionHighVariation = "HIGH_VARIATION"
|
||||||
|
MjActionLowVariation = "LOW_VARIATION"
|
||||||
|
MjActionPan = "PAN"
|
||||||
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
|
)
|
||||||
|
|
||||||
|
var MidjourneyModel2Action = map[string]string{
|
||||||
|
"mj_imagine": MjActionImagine,
|
||||||
|
"mj_describe": MjActionDescribe,
|
||||||
|
"mj_blend": MjActionBlend,
|
||||||
|
"mj_upscale": MjActionUpscale,
|
||||||
|
"mj_variation": MjActionVariation,
|
||||||
|
"mj_reroll": MjActionReRoll,
|
||||||
|
"mj_modal": MjActionModal,
|
||||||
|
"mj_inpaint": MjActionInPaint,
|
||||||
|
"mj_zoom": MjActionZoom,
|
||||||
|
"mj_custom_zoom": MjActionCustomZoom,
|
||||||
|
"mj_shorten": MjActionShorten,
|
||||||
|
"mj_high_variation": MjActionHighVariation,
|
||||||
|
"mj_low_variation": MjActionLowVariation,
|
||||||
|
"mj_pan": MjActionPan,
|
||||||
|
"swap_face": MjActionSwapFace,
|
||||||
|
}
|
@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case common.ChannelTypeCustom:
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
case common.ChannelTypeOpenAISB:
|
//case common.ChannelTypeOpenAISB:
|
||||||
return updateChannelOpenAISBBalance(channel)
|
// return updateChannelOpenAISBBalance(channel)
|
||||||
case common.ChannelTypeAIProxy:
|
case common.ChannelTypeAIProxy:
|
||||||
return updateChannelAIProxyBalance(channel)
|
return updateChannelAIProxyBalance(channel)
|
||||||
case common.ChannelTypeAPI2GPT:
|
case common.ChannelTypeAPI2GPT:
|
||||||
|
@ -10,145 +10,14 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relay2 "one-api/relay"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*func UpdateMidjourneyTask() {
|
|
||||||
//revocer
|
|
||||||
//imageModel := "midjourney"
|
|
||||||
ctx := context.TODO()
|
|
||||||
imageModel := "midjourney"
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(15) * time.Second)
|
|
||||||
tasks := model.GetAllUnFinishTasks()
|
|
||||||
if len(tasks) != 0 {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
|
||||||
for _, task := range tasks {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
|
||||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
|
||||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
|
||||||
task.Status = "FAILURE"
|
|
||||||
task.Progress = "100%"
|
|
||||||
err := task.Update()
|
|
||||||
if err != nil {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
|
||||||
if err != nil {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置超时时间
|
|
||||||
timeout := time.Second * 5
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
|
|
||||||
// 使用带有超时的 context 创建新的请求
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
|
||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
log.Printf("responseBody: %s", string(responseBody))
|
|
||||||
var responseItem Midjourney
|
|
||||||
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
|
|
||||||
err = json.Unmarshal(responseBody, &responseItem)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
|
|
||||||
var responseWithoutStatus MidjourneyWithoutStatus
|
|
||||||
var responseStatus MidjourneyStatus
|
|
||||||
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
|
|
||||||
err2 := json.Unmarshal(responseBody, &responseStatus)
|
|
||||||
if err1 == nil && err2 == nil {
|
|
||||||
jsonData, err3 := json.Marshal(responseWithoutStatus)
|
|
||||||
if err3 != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error1: %v", err3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err4 := json.Unmarshal(jsonData, &responseStatus)
|
|
||||||
if err4 != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error2: %v", err4)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
responseItem.Status = strconv.Itoa(responseStatus.Status)
|
|
||||||
} else {
|
|
||||||
log.Printf("UpdateMidjourneyTask error3: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("UpdateMidjourneyTask error4: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
task.Code = 1
|
|
||||||
task.Progress = responseItem.Progress
|
|
||||||
task.PromptEn = responseItem.PromptEn
|
|
||||||
task.State = responseItem.State
|
|
||||||
task.SubmitTime = responseItem.SubmitTime
|
|
||||||
task.StartTime = responseItem.StartTime
|
|
||||||
task.FinishTime = responseItem.FinishTime
|
|
||||||
task.ImageUrl = responseItem.ImageUrl
|
|
||||||
task.Status = responseItem.Status
|
|
||||||
task.FailReason = responseItem.FailReason
|
|
||||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
|
||||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
|
||||||
task.Progress = "100%"
|
|
||||||
err = model.CacheUpdateUserQuota(task.UserId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("error update user quota cache: " + err.Error())
|
|
||||||
} else {
|
|
||||||
modelRatio := common.GetModelRatio(imageModel)
|
|
||||||
groupRatio := common.GetGroupRatio("default")
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
quota := int(ratio * 1 * 1000)
|
|
||||||
if quota != 0 {
|
|
||||||
err := model.IncreaseUserQuota(task.UserId, quota)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("fail to increase user quota")
|
|
||||||
}
|
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = task.Update()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error5: %v", err)
|
|
||||||
}
|
|
||||||
log.Printf("UpdateMidjourneyTask success: %v", task)
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
//imageModel := "midjourney"
|
//imageModel := "midjourney"
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
|
continue
|
||||||
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var responseItems []relay2.Midjourney
|
var responseItems []dto.MidjourneyDto
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
|
|
||||||
for _, responseItem := range responseItems {
|
for _, responseItem := range responseItems {
|
||||||
task := taskM[responseItem.MjId]
|
task := taskM[responseItem.MjId]
|
||||||
|
|
||||||
|
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||||
|
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||||
|
if useTime > 3600000 && task.Progress != "100%" {
|
||||||
|
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||||
|
responseItem.Status = "FAILURE"
|
||||||
|
}
|
||||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Code = 1
|
task.Code = 1
|
||||||
task.Progress = responseItem.Progress
|
task.Progress = responseItem.Progress
|
||||||
task.PromptEn = responseItem.PromptEn
|
task.PromptEn = responseItem.PromptEn
|
||||||
@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
task.ImageUrl = responseItem.ImageUrl
|
task.ImageUrl = responseItem.ImageUrl
|
||||||
task.Status = responseItem.Status
|
task.Status = responseItem.Status
|
||||||
task.FailReason = responseItem.FailReason
|
task.FailReason = responseItem.FailReason
|
||||||
|
if responseItem.Properties != nil {
|
||||||
|
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||||
|
task.Properties = string(propertiesStr)
|
||||||
|
}
|
||||||
|
if responseItem.Buttons != nil {
|
||||||
|
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||||
|
task.Buttons = string(buttonStr)
|
||||||
|
}
|
||||||
|
|
||||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
|
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
|
||||||
if oldTask.Code != 1 {
|
if oldTask.Code != 1 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
"one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@ -59,8 +60,8 @@ func init() {
|
|||||||
IsBlocking: false,
|
IsBlocking: false,
|
||||||
})
|
})
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
for i := 0; i < constant.APITypeDummy; i++ {
|
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||||
if i == constant.APITypeAIProxyLibrary {
|
if i == relayconstant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
adaptor := relay.GetAdaptor(i)
|
adaptor := relay.GetAdaptor(i)
|
||||||
@ -100,6 +101,17 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: "midjourney",
|
||||||
|
Permission: permission,
|
||||||
|
Root: modelName,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
openAIModelsMap[model.Id] = model
|
openAIModelsMap[model.Id] = model
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
@ -61,60 +60,35 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := relayconstant.RelayModeUnknown
|
relayMode := c.GetInt("relay_mode")
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyImagine
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyBlend
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyDescribe
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyNotify
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
|
||||||
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyTaskFetch
|
|
||||||
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
|
|
||||||
relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
|
|
||||||
}
|
|
||||||
|
|
||||||
var err *dto.MidjourneyResponse
|
var err *dto.MidjourneyResponse
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
err = relay.RelayMidjourneyNotify(c)
|
err = relay.RelayMidjourneyNotify(c)
|
||||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||||
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
|
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
|
case relayconstant.RelayModeSwapFace:
|
||||||
|
err = relay.RelaySwapFace(c)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||||
}
|
}
|
||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retryTimesStr := c.Query("retry")
|
statusCode := http.StatusBadRequest
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
if err.Code == 30 {
|
||||||
if retryTimesStr == "" {
|
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
retryTimes = common.RetryTimes
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
|
||||||
if retryTimes > 0 {
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
|
||||||
} else {
|
|
||||||
if err.Code == 30 {
|
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
|
||||||
}
|
|
||||||
c.JSON(429, gin.H{
|
|
||||||
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
|
|
||||||
"type": "upstream_error",
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||||
|
"type": "upstream_error",
|
||||||
|
"code": err.Code,
|
||||||
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||||
//if shouldDisableChannel(&err.Error) {
|
|
||||||
// channelId := c.GetInt("channel_id")
|
|
||||||
// channelName := c.GetString("channel_name")
|
|
||||||
// disableChannel(channelId, channelName, err.Result)
|
|
||||||
//};''''''''''''''''''''''''''''''''
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,21 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
//type SimpleMjRequest struct {
|
||||||
|
// Prompt string `json:"prompt"`
|
||||||
|
// CustomId string `json:"customId"`
|
||||||
|
// Action string `json:"action"`
|
||||||
|
// Content string `json:"content"`
|
||||||
|
//}
|
||||||
|
|
||||||
|
type SwapFaceRequest struct {
|
||||||
|
SourceBase64 string `json:"sourceBase64"`
|
||||||
|
TargetBase64 string `json:"targetBase64"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyRequest struct {
|
type MidjourneyRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
CustomId string `json:"customId"`
|
||||||
|
BotType string `json:"botType"`
|
||||||
NotifyHook string `json:"notifyHook"`
|
NotifyHook string `json:"notifyHook"`
|
||||||
Action string `json:"action"`
|
Action string `json:"action"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
@ -9,6 +23,7 @@ type MidjourneyRequest struct {
|
|||||||
TaskId string `json:"taskId"`
|
TaskId string `json:"taskId"`
|
||||||
Base64Array []string `json:"base64Array"`
|
Base64Array []string `json:"base64Array"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
MaskBase64 string `json:"maskBase64"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidjourneyResponse struct {
|
type MidjourneyResponse struct {
|
||||||
@ -17,3 +32,64 @@ type MidjourneyResponse struct {
|
|||||||
Properties interface{} `json:"properties"`
|
Properties interface{} `json:"properties"`
|
||||||
Result string `json:"result"`
|
Result string `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MidjourneyResponseWithStatusCode struct {
|
||||||
|
StatusCode int `json:"statusCode"`
|
||||||
|
Response MidjourneyResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyDto struct {
|
||||||
|
MjId string `json:"id"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
CustomId string `json:"customId"`
|
||||||
|
BotType string `json:"botType"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"promptEn"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
SubmitTime int64 `json:"submitTime"`
|
||||||
|
StartTime int64 `json:"startTime"`
|
||||||
|
FinishTime int64 `json:"finishTime"`
|
||||||
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
FailReason string `json:"failReason"`
|
||||||
|
Buttons any `json:"buttons"`
|
||||||
|
MaskBase64 string `json:"maskBase64"`
|
||||||
|
Properties *Properties `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyStatus struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
type MidjourneyWithoutStatus struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
MjId string `json:"mj_id" gorm:"index"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"prompt_en"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
State string `json:"state"`
|
||||||
|
SubmitTime int64 `json:"submit_time"`
|
||||||
|
StartTime int64 `json:"start_time"`
|
||||||
|
FinishTime int64 `json:"finish_time"`
|
||||||
|
ImageUrl string `json:"image_url"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
FailReason string `json:"fail_reason"`
|
||||||
|
ChannelId int `json:"channel_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ActionButton struct {
|
||||||
|
CustomId any `json:"customId"`
|
||||||
|
Emoji any `json:"emoji"`
|
||||||
|
Label any `json:"label"`
|
||||||
|
Type any `json:"type"`
|
||||||
|
Style any `json:"style"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Properties struct {
|
||||||
|
FinalPrompt string `json:"finalPrompt"`
|
||||||
|
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||||
|
}
|
||||||
|
@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled {
|
if !userEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
@ -125,17 +125,11 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
c.Set("token_model_limit_enabled", false)
|
c.Set("token_model_limit_enabled", false)
|
||||||
}
|
}
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
consumeQuota := true
|
|
||||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
|
||||||
consumeQuota = false
|
|
||||||
}
|
|
||||||
c.Set("consume_quota", consumeQuota)
|
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
shouldSelectChannel := true
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
var err error
|
var err error
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||||
// Midjourney
|
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||||
if modelRequest.Model == "" {
|
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||||
modelRequest.Model = "midjourney"
|
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||||
|
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||||
|
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
} else {
|
||||||
|
midjourneyRequest := dto.MidjourneyRequest{}
|
||||||
|
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||||
|
if err != nil {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||||
|
if mjErr != nil {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if midjourneyModel == "" {
|
||||||
|
if !success {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// task fetch, task fetch by condition, notify
|
||||||
|
shouldSelectChannel = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelRequest.Model = midjourneyModel
|
||||||
}
|
}
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
@ -87,60 +118,61 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if tokenModelLimit != nil {
|
if tokenModelLimit != nil {
|
||||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// token model limit is empty, all models are not allowed
|
// token model limit is empty, all models are not allowed
|
||||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set("group", userGroup)
|
||||||
|
if shouldSelectChannel {
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
|
}
|
||||||
|
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||||
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if channel == nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("channel", channel.Type)
|
||||||
|
c.Set("channel_id", channel.Id)
|
||||||
|
c.Set("channel_name", channel.Name)
|
||||||
|
ban := true
|
||||||
|
// parse *int to bool
|
||||||
|
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||||
|
ban = false
|
||||||
|
}
|
||||||
|
c.Set("auto_ban", ban)
|
||||||
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
|
// TODO: api_version统一
|
||||||
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
//case common.ChannelTypeAIProxyLibrary:
|
||||||
|
// c.Set("library_id", channel.Other)
|
||||||
|
case common.ChannelTypeGemini:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
c.Set("plugin", channel.Other)
|
||||||
}
|
}
|
||||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if channel == nil {
|
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
ban := true
|
|
||||||
// parse *int to bool
|
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
|
||||||
ban = false
|
|
||||||
}
|
|
||||||
c.Set("auto_ban", ban)
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
// TODO: api_version统一
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
//case common.ChannelTypeAIProxyLibrary:
|
|
||||||
// c.Set("library_id", channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set("plugin", channel.Other)
|
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), message)
|
common.LogError(c.Request.Context(), message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"description": description,
|
||||||
|
"type": "new_api_error",
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
common.LogError(c.Request.Context(), description)
|
||||||
|
}
|
||||||
|
@ -147,7 +147,12 @@ func FixAbility() (int, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
var channels []Channel
|
var channels []Channel
|
||||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
|
||||||
|
if len(abilityChannelIds) == 0 {
|
||||||
|
err = DB.Find(&channels).Error
|
||||||
|
} else {
|
||||||
|
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,8 @@ type Midjourney struct {
|
|||||||
FailReason string `json:"fail_reason"`
|
FailReason string `json:"fail_reason"`
|
||||||
ChannelId int `json:"channel_id"`
|
ChannelId int `json:"channel_id"`
|
||||||
Quota int `json:"quota"`
|
Quota int `json:"quota"`
|
||||||
|
Buttons string `json:"buttons"`
|
||||||
|
Properties string `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||||
|
@ -17,10 +17,15 @@ const (
|
|||||||
RelayModeMidjourneySimpleChange
|
RelayModeMidjourneySimpleChange
|
||||||
RelayModeMidjourneyNotify
|
RelayModeMidjourneyNotify
|
||||||
RelayModeMidjourneyTaskFetch
|
RelayModeMidjourneyTaskFetch
|
||||||
|
RelayModeMidjourneyTaskImageSeed
|
||||||
RelayModeMidjourneyTaskFetchByCondition
|
RelayModeMidjourneyTaskFetchByCondition
|
||||||
RelayModeAudioSpeech
|
RelayModeAudioSpeech
|
||||||
RelayModeAudioTranscription
|
RelayModeAudioTranscription
|
||||||
RelayModeAudioTranslation
|
RelayModeAudioTranslation
|
||||||
|
RelayModeMidjourneyAction
|
||||||
|
RelayModeMidjourneyModal
|
||||||
|
RelayModeMidjourneyShorten
|
||||||
|
RelayModeSwapFace
|
||||||
)
|
)
|
||||||
|
|
||||||
func Path2RelayMode(path string) int {
|
func Path2RelayMode(path string) int {
|
||||||
@ -48,3 +53,39 @@ func Path2RelayMode(path string) int {
|
|||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Path2RelayModeMidjourney(path string) int {
|
||||||
|
relayMode := RelayModeUnknown
|
||||||
|
if strings.HasPrefix(path, "/mj/submit/action") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeMidjourneyAction
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/modal") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeMidjourneyModal
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeMidjourneyShorten
|
||||||
|
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeSwapFace
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
|
||||||
|
relayMode = RelayModeMidjourneyImagine
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/blend") {
|
||||||
|
relayMode = RelayModeMidjourneyBlend
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/describe") {
|
||||||
|
relayMode = RelayModeMidjourneyDescribe
|
||||||
|
} else if strings.HasPrefix(path, "/mj/notify") {
|
||||||
|
relayMode = RelayModeMidjourneyNotify
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/change") {
|
||||||
|
relayMode = RelayModeMidjourneyChange
|
||||||
|
} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
|
||||||
|
relayMode = RelayModeMidjourneyChange
|
||||||
|
} else if strings.HasSuffix(path, "/fetch") {
|
||||||
|
relayMode = RelayModeMidjourneyTaskFetch
|
||||||
|
} else if strings.HasSuffix(path, "/image-seed") {
|
||||||
|
relayMode = RelayModeMidjourneyTaskImageSeed
|
||||||
|
} else if strings.HasSuffix(path, "/list-by-condition") {
|
||||||
|
relayMode = RelayModeMidjourneyTaskFetchByCondition
|
||||||
|
}
|
||||||
|
return relayMode
|
||||||
|
}
|
||||||
|
@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
var imageRequest dto.ImageRequest
|
var imageRequest dto.ImageRequest
|
||||||
if consumeQuota {
|
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
if err != nil {
|
||||||
if err != nil {
|
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Model == "" {
|
if imageRequest.Model == "" {
|
||||||
@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
|
|
||||||
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
var textResponse dto.ImageResponse
|
var textResponse dto.ImageResponse
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||||
if consumeQuota {
|
if resp.StatusCode != http.StatusOK {
|
||||||
if resp.StatusCode != http.StatusOK {
|
return
|
||||||
return
|
}
|
||||||
}
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
if err != nil {
|
||||||
if err != nil {
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
}
|
||||||
}
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
if err != nil {
|
||||||
if err != nil {
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
}
|
||||||
}
|
if quota != 0 {
|
||||||
if quota != 0 {
|
tokenName := c.GetString("token_name")
|
||||||
tokenName := c.GetString("token_name")
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
channelId := c.GetInt("channel_id")
|
||||||
channelId := c.GetInt("channel_id")
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
|
|
||||||
if consumeQuota {
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
@ -20,53 +21,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Midjourney struct {
|
|
||||||
MjId string `json:"id"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
State string `json:"state"`
|
|
||||||
SubmitTime int64 `json:"submitTime"`
|
|
||||||
StartTime int64 `json:"startTime"`
|
|
||||||
FinishTime int64 `json:"finishTime"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
FailReason string `json:"failReason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MidjourneyStatus struct {
|
|
||||||
Status int `json:"status"`
|
|
||||||
}
|
|
||||||
type MidjourneyWithoutStatus struct {
|
|
||||||
Id int `json:"id"`
|
|
||||||
Code int `json:"code"`
|
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
MjId string `json:"mj_id" gorm:"index"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"prompt_en"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
State string `json:"state"`
|
|
||||||
SubmitTime int64 `json:"submit_time"`
|
|
||||||
StartTime int64 `json:"start_time"`
|
|
||||||
FinishTime int64 `json:"finish_time"`
|
|
||||||
ImageUrl string `json:"image_url"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
FailReason string `json:"fail_reason"`
|
|
||||||
ChannelId int `json:"channel_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultModelPrice = map[string]float64{
|
|
||||||
"mj_imagine": 0.1,
|
|
||||||
"mj_variation": 0.1,
|
|
||||||
"mj_reroll": 0.1,
|
|
||||||
"mj_blend": 0.1,
|
|
||||||
"mj_describe": 0.05,
|
|
||||||
"mj_upscale": 0.05,
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayMidjourneyImage(c *gin.Context) {
|
func RelayMidjourneyImage(c *gin.Context) {
|
||||||
taskId := c.Param("id")
|
taskId := c.Param("id")
|
||||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||||
@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
var midjRequest Midjourney
|
var midjRequest dto.MidjourneyDto
|
||||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
|
func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
|
||||||
midjourneyTask.MjId = originTask.MjId
|
midjourneyTask.MjId = originTask.MjId
|
||||||
midjourneyTask.Progress = originTask.Progress
|
midjourneyTask.Progress = originTask.Progress
|
||||||
midjourneyTask.PromptEn = originTask.PromptEn
|
midjourneyTask.PromptEn = originTask.PromptEn
|
||||||
@ -167,9 +121,182 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
midjourneyTask.Action = originTask.Action
|
midjourneyTask.Action = originTask.Action
|
||||||
midjourneyTask.Description = originTask.Description
|
midjourneyTask.Description = originTask.Description
|
||||||
midjourneyTask.Prompt = originTask.Prompt
|
midjourneyTask.Prompt = originTask.Prompt
|
||||||
|
if originTask.Buttons != "" {
|
||||||
|
var buttons []dto.ActionButton
|
||||||
|
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
|
||||||
|
if err == nil {
|
||||||
|
midjourneyTask.Buttons = buttons
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if originTask.Properties != "" {
|
||||||
|
var properties dto.Properties
|
||||||
|
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||||||
|
if err == nil {
|
||||||
|
midjourneyTask.Properties = &properties
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
|
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
group := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
var swapFaceRequest dto.SwapFaceRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||||
|
}
|
||||||
|
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||||
|
}
|
||||||
|
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||||
|
modelPrice := common.GetModelPrice(modelName, true)
|
||||||
|
// 如果没有配置价格,则使用默认价格
|
||||||
|
if modelPrice == -1 {
|
||||||
|
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||||
|
if !ok {
|
||||||
|
modelPrice = 0.1
|
||||||
|
} else {
|
||||||
|
modelPrice = defaultPrice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groupRatio := common.GetGroupRatio(group)
|
||||||
|
ratio := modelPrice * groupRatio
|
||||||
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
return &dto.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quota := int(ratio * common.QuotaPerUnit)
|
||||||
|
|
||||||
|
if userQuota-quota < 0 {
|
||||||
|
return &dto.MidjourneyResponse{
|
||||||
|
Code: 4,
|
||||||
|
Description: "quota_not_enough",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requestURL := c.Request.URL.String()
|
||||||
|
baseURL := c.GetString("base_url")
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||||||
|
if err != nil {
|
||||||
|
return &mjResp.Response
|
||||||
|
}
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||||
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
midjResponse := &mjResp.Response
|
||||||
|
midjourneyTask := &model.Midjourney{
|
||||||
|
UserId: userId,
|
||||||
|
Code: midjResponse.Code,
|
||||||
|
Action: constant.MjActionSwapFace,
|
||||||
|
MjId: midjResponse.Result,
|
||||||
|
Prompt: "InsightFace",
|
||||||
|
PromptEn: "",
|
||||||
|
Description: midjResponse.Description,
|
||||||
|
State: "",
|
||||||
|
SubmitTime: startTime,
|
||||||
|
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||||
|
FinishTime: 0,
|
||||||
|
ImageUrl: "",
|
||||||
|
Status: "",
|
||||||
|
Progress: "0%",
|
||||||
|
FailReason: "",
|
||||||
|
ChannelId: c.GetInt("channel_id"),
|
||||||
|
Quota: quota,
|
||||||
|
}
|
||||||
|
err = midjourneyTask.Insert()
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
originTask := model.GetByMJId(userId, taskId)
|
||||||
|
if originTask == nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
||||||
|
}
|
||||||
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||||
|
}
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||||
|
}
|
||||||
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
|
||||||
|
requestURL := c.Request.URL.String()
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||||
|
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||||
|
if err != nil {
|
||||||
|
return &midjResponseWithStatus.Response
|
||||||
|
}
|
||||||
|
//defer func(ctx context.Context) {
|
||||||
|
// err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
|
// if err != nil {
|
||||||
|
// common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
// }
|
||||||
|
// err = model.CacheUpdateUserQuota(userId)
|
||||||
|
// if err != nil {
|
||||||
|
// common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
// }
|
||||||
|
// if quota != 0 {
|
||||||
|
// tokenName := c.GetString("token_name")
|
||||||
|
// logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||||
|
// model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
|
// model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
// channelId := c.GetInt("channel_id")
|
||||||
|
// model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
// }
|
||||||
|
//}(c.Request.Context())
|
||||||
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
|
respBody, err := json.Marshal(midjResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||||
|
}
|
||||||
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
|
if err != nil {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
var err error
|
var err error
|
||||||
@ -184,7 +311,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|||||||
Description: "task_no_found",
|
Description: "task_no_found",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||||||
respBody, err = json.Marshal(midjourneyTask)
|
respBody, err = json.Marshal(midjourneyTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
@ -203,16 +330,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|||||||
Description: "do_request_failed",
|
Description: "do_request_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var tasks []Midjourney
|
var tasks []dto.MidjourneyDto
|
||||||
if len(condition.IDs) != 0 {
|
if len(condition.IDs) != 0 {
|
||||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||||
for _, originTask := range originTasks {
|
for _, originTask := range originTasks {
|
||||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||||||
tasks = append(tasks, midjourneyTask)
|
tasks = append(tasks, midjourneyTask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if tasks == nil {
|
if tasks == nil {
|
||||||
tasks = make([]Midjourney, 0)
|
tasks = make([]dto.MidjourneyDto, 0)
|
||||||
}
|
}
|
||||||
respBody, err = json.Marshal(tasks)
|
respBody, err = json.Marshal(tasks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -235,170 +362,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
// type 1 根据 mode 价格不同
|
|
||||||
MJSubmitActionImagine = "IMAGINE"
|
|
||||||
MJSubmitActionVariation = "VARIATION" //变换
|
|
||||||
MJSubmitActionBlend = "BLEND" //混图
|
|
||||||
|
|
||||||
MJSubmitActionReroll = "REROLL" //重新生成
|
|
||||||
// type 2 固定价格
|
|
||||||
MJSubmitActionDescribe = "DESCRIBE"
|
|
||||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||||
imageModel := "midjourney"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
//channelType := c.GetInt("channel")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
consumeQuota := true
|
||||||
var midjRequest dto.MidjourneyRequest
|
var midjRequest dto.MidjourneyRequest
|
||||||
if consumeQuota {
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
if err != nil {
|
||||||
if err != nil {
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||||
return &dto.MidjourneyResponse{
|
}
|
||||||
Code: 4,
|
|
||||||
Description: "bind_request_body_failed",
|
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||||
}
|
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||||
|
if mjErr != nil {
|
||||||
|
return mjErr
|
||||||
}
|
}
|
||||||
|
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||||
if midjRequest.Prompt == "" {
|
if midjRequest.Prompt == "" {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
||||||
Code: 4,
|
|
||||||
Description: "prompt_is_required",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
midjRequest.Action = "IMAGINE"
|
midjRequest.Action = constant.MjActionImagine
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||||
midjRequest.Action = "DESCRIBE"
|
midjRequest.Action = constant.MjActionDescribe
|
||||||
|
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||||
|
midjRequest.Action = constant.MjActionShorten
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = "BLEND"
|
midjRequest.Action = constant.MjActionBlend
|
||||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
mjId := ""
|
mjId := ""
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||||
if midjRequest.TaskId == "" {
|
if midjRequest.TaskId == "" {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||||
Code: 4,
|
|
||||||
Description: "taskId_is_required",
|
|
||||||
}
|
|
||||||
} else if midjRequest.Action == "" {
|
} else if midjRequest.Action == "" {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
|
||||||
Code: 4,
|
|
||||||
Description: "action_is_required",
|
|
||||||
}
|
|
||||||
} else if midjRequest.Index == 0 {
|
} else if midjRequest.Index == 0 {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
|
||||||
Code: 4,
|
|
||||||
Description: "index_can_only_be_1_2_3_4",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
//action = midjRequest.Action
|
//action = midjRequest.Action
|
||||||
mjId = midjRequest.TaskId
|
mjId = midjRequest.TaskId
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||||
if midjRequest.Content == "" {
|
if midjRequest.Content == "" {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||||
Code: 4,
|
|
||||||
Description: "content_is_required",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
params := convertSimpleChangeParams(midjRequest.Content)
|
params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
||||||
Code: 4,
|
|
||||||
Description: "content_parse_failed",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
mjId = params.ID
|
mjId = params.TaskId
|
||||||
midjRequest.Action = params.Action
|
midjRequest.Action = params.Action
|
||||||
|
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
||||||
|
//if midjRequest.MaskBase64 == "" {
|
||||||
|
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
||||||
|
//}
|
||||||
|
mjId = midjRequest.TaskId
|
||||||
|
midjRequest.Action = constant.MjActionModal
|
||||||
}
|
}
|
||||||
|
|
||||||
originTask := model.GetByMJId(userId, mjId)
|
originTask := model.GetByMJId(userId, mjId)
|
||||||
if originTask == nil {
|
if originTask == nil {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||||
Code: 4,
|
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||||
Description: "task_no_found",
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||||
}
|
|
||||||
} else if originTask.Action == "UPSCALE" {
|
|
||||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "upscale_task_can_not_be_change",
|
|
||||||
}
|
|
||||||
} else if originTask.Status != "SUCCESS" {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "task_status_is_not_success",
|
|
||||||
}
|
|
||||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||||
Code: 4,
|
}
|
||||||
Description: "channel_not_found",
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
}
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||||
}
|
}
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
c.Set("channel_id", originTask.ChannelId)
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||||
}
|
}
|
||||||
midjRequest.Prompt = originTask.Prompt
|
midjRequest.Prompt = originTask.Prompt
|
||||||
|
|
||||||
|
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||||
|
// // plus
|
||||||
|
//} else {
|
||||||
|
// // 普通版渠道
|
||||||
|
//
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
|
||||||
modelMapping := c.GetString("model_mapping")
|
consumeQuota = false
|
||||||
isModelMapped := false
|
|
||||||
if modelMapping != "" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "unmarshal_model_mapping_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if modelMap[imageModel] != "" {
|
|
||||||
imageModel = modelMap[imageModel]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
//baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
if c.GetString("base_url") != "" {
|
baseURL := c.GetString("base_url")
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||||
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
log.Printf("fullRequestURL: %s", fullRequestURL)
|
|
||||||
|
|
||||||
var requestBody io.Reader
|
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||||
if isModelMapped {
|
modelPrice := common.GetModelPrice(modelName, true)
|
||||||
jsonStr, err := json.Marshal(midjRequest)
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "marshal_text_request_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
} else {
|
|
||||||
requestBody = c.Request.Body
|
|
||||||
}
|
|
||||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
|
||||||
modelPrice := common.GetModelPrice(mjAction, true)
|
|
||||||
// 如果没有配置价格,则使用默认价格
|
// 如果没有配置价格,则使用默认价格
|
||||||
if modelPrice == -1 {
|
if modelPrice == -1 {
|
||||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||||
if !ok {
|
if !ok {
|
||||||
modelPrice = 0.1
|
modelPrice = 0.1
|
||||||
} else {
|
} else {
|
||||||
@ -423,53 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &midjResponseWithStatus.Response
|
||||||
Code: 4,
|
|
||||||
Description: "create_request_failed",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
//mjToken := ""
|
|
||||||
//if c.Request.Header.Get("ApiKey") != "" {
|
|
||||||
// mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
|
|
||||||
//}
|
|
||||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
|
||||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
|
||||||
// print request header
|
|
||||||
log.Printf("request header: %s", req.Header)
|
|
||||||
log.Printf("request body: %s", midjRequest.Prompt)
|
|
||||||
|
|
||||||
resp, err := service.GetHttpClient().Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "do_request_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "close_request_body_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "close_request_body_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var midjResponse dto.MidjourneyResponse
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
@ -481,7 +514,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
@ -489,41 +522,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
|
|
||||||
//if consumeQuota {
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "read_response_body_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "close_response_body_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.Unmarshal(responseBody, &midjResponse)
|
|
||||||
log.Printf("responseBody: %s", string(responseBody))
|
|
||||||
log.Printf("midjResponse: %v", midjResponse)
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return &dto.MidjourneyResponse{
|
|
||||||
Code: 4,
|
|
||||||
Description: "unmarshal_response_body_failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||||
//1-提交成功
|
//1-提交成功
|
||||||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||||
@ -575,8 +573,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
//修改返回值
|
//修改返回值
|
||||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
|
||||||
responseBody = []byte(newBody)
|
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||||
|
responseBody = []byte(newBody)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = midjourneyTask.Insert()
|
err = midjourneyTask.Insert()
|
||||||
@ -593,21 +593,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
responseBody = []byte(newBody)
|
responseBody = []byte(newBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
//for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
// c.Writer.Header().Set(k, v[0])
|
||||||
}
|
//}
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, bodyReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "copy_response_body_failed",
|
Description: "copy_response_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = bodyReader.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
@ -622,32 +623,3 @@ type taskChangeParams struct {
|
|||||||
Action string
|
Action string
|
||||||
Index int
|
Index int
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
|
||||||
split := strings.Split(content, " ")
|
|
||||||
if len(split) != 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
action := strings.ToLower(split[1])
|
|
||||||
changeParams := &taskChangeParams{}
|
|
||||||
changeParams.ID = split[0]
|
|
||||||
|
|
||||||
if action[0] == 'u' {
|
|
||||||
changeParams.Action = "UPSCALE"
|
|
||||||
} else if action[0] == 'v' {
|
|
||||||
changeParams.Action = "VARIATION"
|
|
||||||
} else if action == "r" {
|
|
||||||
changeParams.Action = "REROLL"
|
|
||||||
return changeParams
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
index, err := strconv.Atoi(action[1:2])
|
|
||||||
if err != nil || index < 1 || index > 4 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
changeParams.Index = index
|
|
||||||
return changeParams
|
|
||||||
}
|
|
||||||
|
@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
|
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
|
||||||
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
|
relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
|
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
|
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
|
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
|
||||||
@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||||
}
|
}
|
||||||
//relayMjRouter.Use()
|
//relayMjRouter.Use()
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,20 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
|
||||||
|
return &dto.MidjourneyResponse{
|
||||||
|
Code: code,
|
||||||
|
Description: desc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
|
||||||
|
return &dto.MidjourneyResponseWithStatusCode{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Response: *MidjourneyErrorWrapper(code, desc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
|
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
|
||||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
text := err.Error()
|
text := err.Error()
|
||||||
|
224
service/midjourney.go
Normal file
224
service/midjourney.go
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CoverActionToModelName(mjAction string) string {
|
||||||
|
modelName := "mj_" + strings.ToLower(mjAction)
|
||||||
|
if mjAction == constant.MjActionSwapFace {
|
||||||
|
modelName = "swap_face"
|
||||||
|
}
|
||||||
|
return modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
|
||||||
|
action := ""
|
||||||
|
if relayMode == relayconstant.RelayModeMidjourneyAction {
|
||||||
|
// plus request
|
||||||
|
err := CoverPlusActionToNormalAction(midjRequest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err, false
|
||||||
|
}
|
||||||
|
action = midjRequest.Action
|
||||||
|
} else {
|
||||||
|
switch relayMode {
|
||||||
|
case relayconstant.RelayModeMidjourneyImagine:
|
||||||
|
action = constant.MjActionImagine
|
||||||
|
case relayconstant.RelayModeMidjourneyDescribe:
|
||||||
|
action = constant.MjActionDescribe
|
||||||
|
case relayconstant.RelayModeMidjourneyBlend:
|
||||||
|
action = constant.MjActionBlend
|
||||||
|
case relayconstant.RelayModeMidjourneyShorten:
|
||||||
|
action = constant.MjActionShorten
|
||||||
|
case relayconstant.RelayModeMidjourneyChange:
|
||||||
|
action = midjRequest.Action
|
||||||
|
case relayconstant.RelayModeMidjourneyModal:
|
||||||
|
action = constant.MjActionModal
|
||||||
|
case relayconstant.RelayModeSwapFace:
|
||||||
|
action = constant.MjActionSwapFace
|
||||||
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||||
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
|
if params == nil {
|
||||||
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
|
||||||
|
}
|
||||||
|
action = params.Action
|
||||||
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
|
||||||
|
return "", nil, true
|
||||||
|
default:
|
||||||
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelName := CoverActionToModelName(action)
|
||||||
|
return modelName, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
||||||
|
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||||
|
customId := midjRequest.CustomId
|
||||||
|
if customId == "" {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
||||||
|
}
|
||||||
|
splits := strings.Split(customId, "::")
|
||||||
|
var action string
|
||||||
|
if splits[1] == "JOB" {
|
||||||
|
action = splits[2]
|
||||||
|
} else {
|
||||||
|
action = splits[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "" {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||||
|
}
|
||||||
|
if strings.Contains(action, "upsample") {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = constant.MjActionUpscale
|
||||||
|
} else if strings.Contains(action, "variation") {
|
||||||
|
midjRequest.Index = 1
|
||||||
|
if action == "variation" {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = constant.MjActionVariation
|
||||||
|
} else if action == "low_variation" {
|
||||||
|
midjRequest.Action = constant.MjActionLowVariation
|
||||||
|
} else if action == "high_variation" {
|
||||||
|
midjRequest.Action = constant.MjActionHighVariation
|
||||||
|
}
|
||||||
|
} else if strings.Contains(action, "pan") {
|
||||||
|
midjRequest.Action = constant.MjActionPan
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if strings.Contains(action, "reroll") {
|
||||||
|
midjRequest.Action = constant.MjActionReRoll
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Outpaint" {
|
||||||
|
midjRequest.Action = constant.MjActionZoom
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "CustomZoom" {
|
||||||
|
midjRequest.Action = constant.MjActionCustomZoom
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Inpaint" {
|
||||||
|
midjRequest.Action = constant.MjActionInPaint
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
||||||
|
split := strings.Split(content, " ")
|
||||||
|
if len(split) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
action := strings.ToLower(split[1])
|
||||||
|
changeParams := &dto.MidjourneyRequest{}
|
||||||
|
changeParams.TaskId = split[0]
|
||||||
|
|
||||||
|
if action[0] == 'u' {
|
||||||
|
changeParams.Action = "UPSCALE"
|
||||||
|
} else if action[0] == 'v' {
|
||||||
|
changeParams.Action = "VARIATION"
|
||||||
|
} else if action == "r" {
|
||||||
|
changeParams.Action = "REROLL"
|
||||||
|
return changeParams
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := strconv.Atoi(action[1:2])
|
||||||
|
if err != nil || index < 1 || index > 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
changeParams.Index = index
|
||||||
|
return changeParams
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
||||||
|
var nullBytes []byte
|
||||||
|
//var requestBody io.Reader
|
||||||
|
//requestBody = c.Request.Body
|
||||||
|
// read request body to json, delete accountFilter and notifyHook
|
||||||
|
var mapResult map[string]interface{}
|
||||||
|
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
delete(mapResult, "accountFilter")
|
||||||
|
delete(mapResult, "notifyHook")
|
||||||
|
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
// make new request with mapResult
|
||||||
|
reqBody, err := json.Marshal(mapResult)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
// 使用带有超时的 context 创建新的请求
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||||
|
defer cancel()
|
||||||
|
resp, err := GetHttpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||||
|
}
|
||||||
|
statusCode := resp.StatusCode
|
||||||
|
//if statusCode != 200 {
|
||||||
|
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
|
||||||
|
//}
|
||||||
|
err = req.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
err = c.Request.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
var midjResponse dto.MidjourneyResponse
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(responseBody, &midjResponse)
|
||||||
|
log.Printf("responseBody: %s", string(responseBody))
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
|
//log.Printf("midjResponse: %v", midjResponse)
|
||||||
|
//for k, v := range resp.Header {
|
||||||
|
// c.Writer.Header().Set(k, v[0])
|
||||||
|
//}
|
||||||
|
return &dto.MidjourneyResponseWithStatusCode{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Response: midjResponse,
|
||||||
|
}, responseBody, nil
|
||||||
|
}
|
@ -31,10 +31,30 @@ function renderType(type) {
|
|||||||
return <Tag color="orange" size='large'>放大</Tag>;
|
return <Tag color="orange" size='large'>放大</Tag>;
|
||||||
case 'VARIATION':
|
case 'VARIATION':
|
||||||
return <Tag color="purple" size='large'>变换</Tag>;
|
return <Tag color="purple" size='large'>变换</Tag>;
|
||||||
|
case 'HIGH_VARIATION':
|
||||||
|
return <Tag color="purple" size='large'>强变换</Tag>;
|
||||||
|
case 'LOW_VARIATION':
|
||||||
|
return <Tag color="purple" size='large'>弱变换</Tag>;
|
||||||
|
case 'PAN':
|
||||||
|
return <Tag color="cyan" size='large'>平移</Tag>;
|
||||||
case 'DESCRIBE':
|
case 'DESCRIBE':
|
||||||
return <Tag color="yellow" size='large'>图生文</Tag>;
|
return <Tag color="yellow" size='large'>图生文</Tag>;
|
||||||
case 'BLEAND':
|
case 'BLEND':
|
||||||
return <Tag color="lime" size='large'>图混合</Tag>;
|
return <Tag color="lime" size='large'>图混合</Tag>;
|
||||||
|
case 'SHORTEN':
|
||||||
|
return <Tag color="pink" size='large'>缩词</Tag>;
|
||||||
|
case 'REROLL':
|
||||||
|
return <Tag color="indigo" size='large'>重绘</Tag>;
|
||||||
|
case 'INPAINT':
|
||||||
|
return <Tag color="violet" size='large'>局部重绘-提交</Tag>;
|
||||||
|
case 'ZOOM':
|
||||||
|
return <Tag color="teal" size='large'>变焦</Tag>;
|
||||||
|
case 'CUSTOM_ZOOM':
|
||||||
|
return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
|
||||||
|
case 'MODAL':
|
||||||
|
return <Tag color="green" size='large'>窗口处理</Tag>;
|
||||||
|
case 'SWAP_FACE':
|
||||||
|
return <Tag color="light-green" size='large'>换脸</Tag>;
|
||||||
default:
|
default:
|
||||||
return <Tag color="white" size='large'>未知</Tag>;
|
return <Tag color="white" size='large'>未知</Tag>;
|
||||||
}
|
}
|
||||||
@ -46,9 +66,11 @@ function renderCode(code) {
|
|||||||
case 1:
|
case 1:
|
||||||
return <Tag color="green" size='large'>已提交</Tag>;
|
return <Tag color="green" size='large'>已提交</Tag>;
|
||||||
case 21:
|
case 21:
|
||||||
return <Tag color="lime" size='large'>排队中</Tag>;
|
return <Tag color="lime" size='large'>等待中</Tag>;
|
||||||
case 22:
|
case 22:
|
||||||
return <Tag color="orange" size='large'>重复提交</Tag>;
|
return <Tag color="orange" size='large'>重复提交</Tag>;
|
||||||
|
case 0:
|
||||||
|
return <Tag color="yellow" size='large'>未提交</Tag>;
|
||||||
default:
|
default:
|
||||||
return <Tag color="white" size='large'>未知</Tag>;
|
return <Tag color="white" size='large'>未知</Tag>;
|
||||||
}
|
}
|
||||||
@ -68,6 +90,8 @@ function renderStatus(type) {
|
|||||||
return <Tag color="blue" size='large'>执行中</Tag>;
|
return <Tag color="blue" size='large'>执行中</Tag>;
|
||||||
case 'FAILURE':
|
case 'FAILURE':
|
||||||
return <Tag color="red" size='large'>失败</Tag>;
|
return <Tag color="red" size='large'>失败</Tag>;
|
||||||
|
case 'MODAL':
|
||||||
|
return <Tag color="yellow" size='large'>窗口等待</Tag>;
|
||||||
default:
|
default:
|
||||||
return <Tag color="white" size='large'>未知</Tag>;
|
return <Tag color="white" size='large'>未知</Tag>;
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
export const CHANNEL_OPTIONS = [
|
export const CHANNEL_OPTIONS = [
|
||||||
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
|
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
|
||||||
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
|
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
|
||||||
|
{key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'},
|
||||||
{key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
|
{key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
|
||||||
{key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
|
{key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
|
||||||
{key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},
|
{key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},
|
||||||
|
@ -95,6 +95,28 @@ const EditChannel = (props) => {
|
|||||||
case 26:
|
case 26:
|
||||||
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
|
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
|
||||||
break;
|
break;
|
||||||
|
case 2:
|
||||||
|
localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
localModels = [
|
||||||
|
'swap_face',
|
||||||
|
'mj_imagine',
|
||||||
|
'mj_variation',
|
||||||
|
'mj_reroll',
|
||||||
|
'mj_blend',
|
||||||
|
'mj_upscale',
|
||||||
|
'mj_describe',
|
||||||
|
'mj_zoom',
|
||||||
|
'mj_shorten',
|
||||||
|
'mj_modal',
|
||||||
|
'mj_inpaint',
|
||||||
|
'mj_custom_zoom',
|
||||||
|
'mj_high_variation',
|
||||||
|
'mj_low_variation',
|
||||||
|
'mj_pan',
|
||||||
|
];
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
setInputs((inputs) => ({...inputs, models: localModels}));
|
setInputs((inputs) => ({...inputs, models: localModels}));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user