diff --git a/Midjourney.md b/Midjourney.md
index fe4d433..5733a11 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -4,288 +4,64 @@
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
+### 模型列表
+
+### midjourney-proxy支持
+
+- mj_imagine (绘图)
+- mj_variation (变换)
+- mj_reroll (重绘)
+- mj_blend (混合)
+- mj_upscale (放大)
+- mj_describe (图生文)
+
+### 仅midjourney-proxy-plus支持
+
+- mj_zoom (比例变焦)
+- mj_shorten (提示词缩短)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
+- mj_high_variation (强变换)
+- mj_low_variation (弱变换)
+- mj_pan (平移)
+- swap_face (换脸)
+
```json
{
- "gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
+ "mj_modal": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
"mj_describe": 0.05,
- "mj_upscale": 0.05
+ "mj_upscale": 0.05,
+ "swap_face": 0.05
}
```
## 渠道设置
-### 对接 midjourney-proxy
-1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
+### 对接 midjourney-proxy(plus)
+
+1.
+
+部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
+
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
+ ,模型选择midjourney,如果有换脸模型,可以选择swap_face
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
-1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
-2. 地址填写上游new api的地址,例如:http://localhost:8080
-3. 密钥填写上游new api的密钥
-## 任务提交
-
-### 绘图变化
-
-**接口地址**:`/mj/submit/change`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "action"
-:
- "UPSCALE",
- "index"
-:
- 1,
- "notifyHook"
-:
- "",
- "state"
-:
- "",
- "taskId"
-:
- "1320098173412546"
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
-| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
-| action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
-| index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| state | 自定义参数 | | false | string | |
-| taskId | 任务ID | | true | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-### 提交Imagine任务
-
-**接口地址**:`/mj/submit/imagine`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "base64"
-:
- "",
- "notifyHook"
-:
- "",
- "prompt"
-:
- "Cat",
- "state"
-:
- ""
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-------------------------|------|-------|-------------|-------------|
-| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
-| base64 | 垫图base64 | | false | string | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| prompt | 提示词 | | true | string | |
-| state | 自定义参数 | | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-## 任务查询
-
-### 指定ID获取任务
-
-**接口地址**:`/mj/task/{id}/fetch`
-
-**请求方式**:`GET`
-
-**请求数据类型**:`application/x-www-form-urlencoded`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------|------|------|-------|--------|--------|
-| id | 任务ID | path | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 任务 |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|----------------------------------------------------------|----------------|----------------|
-| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
-| description | 任务描述 | string | |
-| failReason | 失败原因 | string | |
-| finishTime | 结束时间 | integer(int64) | integer(int64) |
-| id | 任务ID | string | |
-| imageUrl | 图片url | string | |
-| progress | 任务进度 | string | |
-| prompt | 提示词 | string | |
-| promptEn | 提示词-英文 | string | |
-| startTime | 开始执行时间 | integer(int64) | integer(int64) |
-| state | 自定义参数 | string | |
-| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
-| submitTime | 提交时间 | integer(int64) | integer(int64) |
-
-**响应示例**:
-
-```javascript
-{
- "action"
-:
- "",
- "description"
-:
- "",
- "failReason"
-:
- "",
- "finishTime"
-:
- 0,
- "id"
-:
- "",
- "imageUrl"
-:
- "",
- "progress"
-:
- "",
- "prompt"
-:
- "",
- "promptEn"
-:
- "",
- "startTime"
-:
- 0,
- "state"
-:
- "",
- "status"
-:
- "",
- "submitTime"
-:
- 0
-}
-```
\ No newline at end of file
+1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
+2. 地址填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/README.md b/README.md
index f1d18da..ce3f27f 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
此分叉版本的主要变更如下:
1. 全新的UI界面(部分界面还待更新)
-2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
+2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
@@ -26,6 +26,11 @@
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
+ [x] /task/list-by-condition
+ + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
+ + [x] /mj/submit/modal
+ + [x] /mj/submit/shorten
+ + [x] /mj/task/{id}/image-seed
+ + [x] /mj/insight-face/swap (InsightFace)
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
@@ -49,6 +54,7 @@
2. 智谱glm-4v,glm-4v识图
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
+5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
diff --git a/common/constants.go b/common/constants.go
index 3e67c27..9ba7d9e 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -193,7 +193,7 @@ const (
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
- ChannelTypeOpenAISB = 5
+ ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 3836e05..5e6163f 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -95,17 +95,31 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
-var ModelPrice = map[string]float64{
- "gpt-4-gizmo-*": 0.1,
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
+var DefaultModelPrice = map[string]float64{
+ "gpt-4-gizmo-*": 0.1,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_modal": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05,
}
+var ModelPrice = map[string]float64{}
+
func ModelPrice2JSONString() string {
+ if len(ModelPrice) == 0 {
+ ModelPrice = DefaultModelPrice
+ }
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
@@ -119,6 +133,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
}
func GetModelPrice(name string, printErr bool) float64 {
+ if len(ModelPrice) == 0 {
+ ModelPrice = DefaultModelPrice
+ }
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
diff --git a/constant/midjourney.go b/constant/midjourney.go
new file mode 100644
index 0000000..3d321ca
--- /dev/null
+++ b/constant/midjourney.go
@@ -0,0 +1,42 @@
+package constant
+
+const (
+ MjErrorUnknown = 5
+ MjRequestError = 4
+)
+
+const (
+ MjActionImagine = "IMAGINE"
+ MjActionDescribe = "DESCRIBE"
+ MjActionBlend = "BLEND"
+ MjActionUpscale = "UPSCALE"
+ MjActionVariation = "VARIATION"
+ MjActionReRoll = "REROLL"
+ MjActionInPaint = "INPAINT"
+ MjActionModal = "MODAL"
+ MjActionZoom = "ZOOM"
+ MjActionCustomZoom = "CUSTOM_ZOOM"
+ MjActionShorten = "SHORTEN"
+ MjActionHighVariation = "HIGH_VARIATION"
+ MjActionLowVariation = "LOW_VARIATION"
+ MjActionPan = "PAN"
+ MjActionSwapFace = "SWAP_FACE"
+)
+
+var MidjourneyModel2Action = map[string]string{
+ "mj_imagine": MjActionImagine,
+ "mj_describe": MjActionDescribe,
+ "mj_blend": MjActionBlend,
+ "mj_upscale": MjActionUpscale,
+ "mj_variation": MjActionVariation,
+ "mj_reroll": MjActionReRoll,
+ "mj_modal": MjActionModal,
+ "mj_inpaint": MjActionInPaint,
+ "mj_zoom": MjActionZoom,
+ "mj_custom_zoom": MjActionCustomZoom,
+ "mj_shorten": MjActionShorten,
+ "mj_high_variation": MjActionHighVariation,
+ "mj_low_variation": MjActionLowVariation,
+ "mj_pan": MjActionPan,
+ "swap_face": MjActionSwapFace,
+}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 4bcd4d4..96f82ee 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
- case common.ChannelTypeOpenAISB:
- return updateChannelOpenAISBBalance(channel)
+ //case common.ChannelTypeOpenAISB:
+ // return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 1a42270..41db4bf 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -10,145 +10,14 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/dto"
"one-api/model"
- relay2 "one-api/relay"
"one-api/service"
"strconv"
"strings"
"time"
)
-/*func UpdateMidjourneyTask() {
- //revocer
- //imageModel := "midjourney"
- ctx := context.TODO()
- imageModel := "midjourney"
- defer func() {
- if err := recover(); err != nil {
- log.Printf("UpdateMidjourneyTask panic: %v", err)
- }
- }()
- for {
- time.Sleep(time.Duration(15) * time.Second)
- tasks := model.GetAllUnFinishTasks()
- if len(tasks) != 0 {
- common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
- for _, task := range tasks {
- common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
- midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
- task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
- task.Status = "FAILURE"
- task.Progress = "100%"
- err := task.Update()
- if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
- continue
- }
- continue
- }
- requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
- common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
-
- req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
- if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
- continue
- }
-
- // 设置超时时间
- timeout := time.Second * 5
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
-
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
-
- req.Header.Set("Content-Type", "application/json")
- //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
- req.Header.Set("mj-api-secret", midjourneyChannel.Key)
- resp, err := httpClient.Do(req)
- if err != nil {
- log.Printf("UpdateMidjourneyTask error: %v", err)
- continue
- }
- responseBody, err := io.ReadAll(resp.Body)
- resp.Body.Close()
- log.Printf("responseBody: %s", string(responseBody))
- var responseItem Midjourney
- // err = json.NewDecoder(resp.Body).Decode(&responseItem)
- err = json.Unmarshal(responseBody, &responseItem)
- if err != nil {
- if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
- var responseWithoutStatus MidjourneyWithoutStatus
- var responseStatus MidjourneyStatus
- err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
- err2 := json.Unmarshal(responseBody, &responseStatus)
- if err1 == nil && err2 == nil {
- jsonData, err3 := json.Marshal(responseWithoutStatus)
- if err3 != nil {
- log.Printf("UpdateMidjourneyTask error1: %v", err3)
- continue
- }
- err4 := json.Unmarshal(jsonData, &responseStatus)
- if err4 != nil {
- log.Printf("UpdateMidjourneyTask error2: %v", err4)
- continue
- }
- responseItem.Status = strconv.Itoa(responseStatus.Status)
- } else {
- log.Printf("UpdateMidjourneyTask error3: %v", err)
- continue
- }
- } else {
- log.Printf("UpdateMidjourneyTask error4: %v", err)
- continue
- }
- }
- task.Code = 1
- task.Progress = responseItem.Progress
- task.PromptEn = responseItem.PromptEn
- task.State = responseItem.State
- task.SubmitTime = responseItem.SubmitTime
- task.StartTime = responseItem.StartTime
- task.FinishTime = responseItem.FinishTime
- task.ImageUrl = responseItem.ImageUrl
- task.Status = responseItem.Status
- task.FailReason = responseItem.FailReason
- if task.Progress != "100%" && responseItem.FailReason != "" {
- common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
- task.Progress = "100%"
- err = model.CacheUpdateUserQuota(task.UserId)
- if err != nil {
- log.Println("error update user quota cache: " + err.Error())
- } else {
- modelRatio := common.GetModelRatio(imageModel)
- groupRatio := common.GetGroupRatio("default")
- ratio := modelRatio * groupRatio
- quota := int(ratio * 1 * 1000)
- if quota != 0 {
- err := model.IncreaseUserQuota(task.UserId, quota)
- if err != nil {
- log.Println("fail to increase user quota")
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
-
- err = task.Update()
- if err != nil {
- log.Printf("UpdateMidjourneyTask error5: %v", err)
- }
- log.Printf("UpdateMidjourneyTask success: %v", task)
- cancel()
- }
- }
- }
-}
-*/
-
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
@@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
+ if resp.StatusCode != http.StatusOK {
+ common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ continue
+ }
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
- var responseItems []relay2.Midjourney
+ var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
+
+ useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
+ // 如果时间超过一小时,且进度不是100%,则认为任务失败
+ if useTime > 3600000 && task.Progress != "100%" {
+ responseItem.FailReason = "上游任务超时(超过1小时)"
+ responseItem.Status = "FAILURE"
+ }
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
-
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
+ if responseItem.Properties != nil {
+ propertiesStr, _ := json.Marshal(responseItem.Properties)
+ task.Properties = string(propertiesStr)
+ }
+ if responseItem.Buttons != nil {
+ buttonStr, _ := json.Marshal(responseItem.Buttons)
+ task.Buttons = string(buttonStr)
+ }
+
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
}
}
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}
diff --git a/controller/model.go b/controller/model.go
index 38c6c46..9a106aa 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -4,12 +4,13 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
- "one-api/relay/constant"
+ relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -59,8 +60,8 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- for i := 0; i < constant.APITypeDummy; i++ {
- if i == constant.APITypeAIProxyLibrary {
+ for i := 0; i < relayconstant.APITypeDummy; i++ {
+ if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -100,6 +101,17 @@ func init() {
Parent: nil,
})
}
+ for modelName, _ := range constant.MidjourneyModel2Action {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "midjourney",
+ Permission: permission,
+ Root: modelName,
+ Parent: nil,
+ })
+ }
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
diff --git a/controller/relay.go b/controller/relay.go
index 911a7c5..9f866b8 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -12,7 +12,6 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
- "strings"
)
func Relay(c *gin.Context) {
@@ -61,60 +60,35 @@ func Relay(c *gin.Context) {
}
func RelayMidjourney(c *gin.Context) {
- relayMode := relayconstant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
- relayMode = relayconstant.RelayModeMidjourneyImagine
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
- relayMode = relayconstant.RelayModeMidjourneyBlend
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
- relayMode = relayconstant.RelayModeMidjourneyDescribe
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
- relayMode = relayconstant.RelayModeMidjourneyNotify
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
- relayMode = relayconstant.RelayModeMidjourneyChange
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
- relayMode = relayconstant.RelayModeMidjourneyChange
- } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
- relayMode = relayconstant.RelayModeMidjourneyTaskFetch
- } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
- relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
- }
-
+ relayMode := c.GetInt("relay_mode")
var err *dto.MidjourneyResponse
switch relayMode {
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
+ case relayconstant.RelayModeMidjourneyTaskImageSeed:
+ err = relay.RelayMidjourneyTaskImageSeed(c)
+ case relayconstant.RelayModeSwapFace:
+ err = relay.RelaySwapFace(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = common.RetryTimes
- }
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.Code == 30 {
- err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
- }
- c.JSON(429, gin.H{
- "error": fmt.Sprintf("%s %s", err.Description, err.Result),
- "type": "upstream_error",
- })
+ statusCode := http.StatusBadRequest
+ if err.Code == 30 {
+ err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ statusCode = http.StatusTooManyRequests
}
+ c.JSON(statusCode, gin.H{
+ "description": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "type": "upstream_error",
+ "code": err.Code,
+ })
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
- //if shouldDisableChannel(&err.Error) {
- // channelId := c.GetInt("channel_id")
- // channelName := c.GetString("channel_name")
- // disableChannel(channelId, channelName, err.Result)
- //};''''''''''''''''''''''''''''''''
}
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 4c67909..c675f7e 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -1,7 +1,21 @@
package dto
+//type SimpleMjRequest struct {
+// Prompt string `json:"prompt"`
+// CustomId string `json:"customId"`
+// Action string `json:"action"`
+// Content string `json:"content"`
+//}
+
+type SwapFaceRequest struct {
+ SourceBase64 string `json:"sourceBase64"`
+ TargetBase64 string `json:"targetBase64"`
+}
+
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
@@ -9,6 +23,7 @@ type MidjourneyRequest struct {
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
+ MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
@@ -17,3 +32,64 @@ type MidjourneyResponse struct {
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
+
+type MidjourneyResponseWithStatusCode struct {
+ StatusCode int `json:"statusCode"`
+ Response MidjourneyResponse
+}
+
+type MidjourneyDto struct {
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+ Properties *Properties `json:"properties"`
+}
+
+type MidjourneyStatus struct {
+ Status int `json:"status"`
+}
+type MidjourneyWithoutStatus struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ ImageUrl string `json:"image_url"`
+ Progress string `json:"progress"`
+ FailReason string `json:"fail_reason"`
+ ChannelId int `json:"channel_id"`
+}
+
+type ActionButton struct {
+ CustomId any `json:"customId"`
+ Emoji any `json:"emoji"`
+ Label any `json:"label"`
+ Type any `json:"type"`
+ Style any `json:"style"`
+}
+
+type Properties struct {
+ FinalPrompt string `json:"finalPrompt"`
+ FinalZhPrompt string `json:"finalZhPrompt"`
+}
diff --git a/middleware/auth.go b/middleware/auth.go
index ef774f6..4b865c2 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
}
token, err := model.ValidateUserToken(key)
if err != nil {
- abortWithMessage(c, http.StatusUnauthorized, err.Error())
+ abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
- abortWithMessage(c, http.StatusInternalServerError, err.Error())
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
- abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
@@ -125,17 +125,11 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
- requestURL := c.Request.URL.String()
- consumeQuota := true
- if strings.HasPrefix(requestURL, "/v1/models") {
- consumeQuota = false
- }
- c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
- abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1ca43dd..ed457a3 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -4,7 +4,11 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/constant"
+ "one-api/dto"
"one-api/model"
+ relayconstant "one-api/relay/constant"
+ "one-api/service"
"strconv"
"strings"
@@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
- abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
+ shouldSelectChannel := true
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
- // Midjourney
- if modelRequest.Model == "" {
- modelRequest.Model = "midjourney"
+ relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
+ relayMode == relayconstant.RelayModeMidjourneyNotify ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
+ shouldSelectChannel = false
+ } else {
+ midjourneyRequest := dto.MidjourneyRequest{}
+ err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
+ if err != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
+ return
+ }
+ midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
+ if mjErr != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
+ return
+ }
+ if midjourneyModel == "" {
+ if !success {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
+ return
+ } else {
+ // task fetch, task fetch by condition, notify
+ shouldSelectChannel = false
+ }
+ }
+ modelRequest.Model = midjourneyModel
}
+ c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -87,60 +118,61 @@ func Distribute() func(c *gin.Context) {
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
- abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
- abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
-
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
- // 如果错误,但是渠道不为空,说明是数据库一致性问题
- if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
- message = "数据库一致性已被破坏,请联系管理员"
+ if shouldSelectChannel {
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ if err != nil {
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ // 如果错误,但是渠道不为空,说明是数据库一致性问题
+ if channel != nil {
+ common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ message = "数据库一致性已被破坏,请联系管理员"
+ }
+ // 如果错误,而且渠道为空,说明是没有可用渠道
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+ return
+ }
+ if channel == nil {
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
+ return
+ }
+ c.Set("channel", channel.Type)
+ c.Set("channel_id", channel.Id)
+ c.Set("channel_name", channel.Name)
+ ban := true
+ // parse *int to bool
+ if channel.AutoBan != nil && *channel.AutoBan == 0 {
+ ban = false
+ }
+ c.Set("auto_ban", ban)
+ c.Set("model_mapping", channel.GetModelMapping())
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ c.Set("base_url", channel.GetBaseURL())
+ // TODO: api_version统一
+ switch channel.Type {
+ case common.ChannelTypeAzure:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeXunfei:
+ c.Set("api_version", channel.Other)
+ //case common.ChannelTypeAIProxyLibrary:
+ // c.Set("library_id", channel.Other)
+ case common.ChannelTypeGemini:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeAli:
+ c.Set("plugin", channel.Other)
}
- // 如果错误,而且渠道为空,说明是没有可用渠道
- abortWithMessage(c, http.StatusServiceUnavailable, message)
- return
}
- if channel == nil {
- abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
- return
- }
- }
- c.Set("channel", channel.Type)
- c.Set("channel_id", channel.Id)
- c.Set("channel_name", channel.Name)
- ban := true
- // parse *int to bool
- if channel.AutoBan != nil && *channel.AutoBan == 0 {
- ban = false
- }
- c.Set("auto_ban", ban)
- c.Set("model_mapping", channel.GetModelMapping())
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- c.Set("base_url", channel.GetBaseURL())
- // TODO: api_version统一
- switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- //case common.ChannelTypeAIProxyLibrary:
- // c.Set("library_id", channel.Other)
- case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
}
c.Next()
}
diff --git a/middleware/utils.go b/middleware/utils.go
index 021002d..43801c1 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -5,7 +5,7 @@ import (
"one-api/common"
)
-func abortWithMessage(c *gin.Context, statusCode int, message string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
common.LogError(c.Request.Context(), message)
}
+
+func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
+ c.JSON(statusCode, gin.H{
+ "description": description,
+ "type": "new_api_error",
+ "code": code,
+ })
+ c.Abort()
+ common.LogError(c.Request.Context(), description)
+}
diff --git a/model/ability.go b/model/ability.go
index 7a81cc2..b79978d 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -147,7 +147,12 @@ func FixAbility() (int, error) {
return 0, err
}
var channels []Channel
- err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+
+ if len(abilityChannelIds) == 0 {
+ err = DB.Find(&channels).Error
+ } else {
+ err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+ }
if err != nil {
return 0, err
}
diff --git a/model/midjourney.go b/model/midjourney.go
index 0ef2e55..dd065a3 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -19,6 +19,8 @@ type Midjourney struct {
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
+ Buttons string `json:"buttons"`
+ Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index beea7dc..1790c57 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -17,10 +17,15 @@ const (
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
+ RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
+ RelayModeMidjourneyAction
+ RelayModeMidjourneyModal
+ RelayModeMidjourneyShorten
+ RelayModeSwapFace
)
func Path2RelayMode(path string) int {
@@ -48,3 +53,39 @@ func Path2RelayMode(path string) int {
}
return relayMode
}
+
+func Path2RelayModeMidjourney(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/mj/submit/action") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyAction
+ } else if strings.HasPrefix(path, "/mj/submit/modal") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyModal
+ } else if strings.HasPrefix(path, "/mj/submit/shorten") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyShorten
+ } else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+ // midjourney plus
+ relayMode = RelayModeSwapFace
+ } else if strings.HasPrefix(path, "/mj/submit/imagine") {
+ relayMode = RelayModeMidjourneyImagine
+ } else if strings.HasPrefix(path, "/mj/submit/blend") {
+ relayMode = RelayModeMidjourneyBlend
+ } else if strings.HasPrefix(path, "/mj/submit/describe") {
+ relayMode = RelayModeMidjourneyDescribe
+ } else if strings.HasPrefix(path, "/mj/notify") {
+ relayMode = RelayModeMidjourneyNotify
+ } else if strings.HasPrefix(path, "/mj/submit/change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasPrefix(path, "/mj/submit/simple-change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasSuffix(path, "/fetch") {
+ relayMode = RelayModeMidjourneyTaskFetch
+ } else if strings.HasSuffix(path, "/image-seed") {
+ relayMode = RelayModeMidjourneyTaskImageSeed
+ } else if strings.HasSuffix(path, "/list-by-condition") {
+ relayMode = RelayModeMidjourneyTaskFetchByCondition
+ }
+ return relayMode
+}
diff --git a/relay/relay-image.go b/relay/relay-image.go
index 3065496..aabe4ba 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
startTime := time.Now()
var imageRequest dto.ImageRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &imageRequest)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
+ err := common.UnmarshalBodyReusable(c, &imageRequest)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.Model == "" {
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
- if consumeQuota && userQuota-quota < 0 {
+ if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
var textResponse dto.ImageResponse
defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
- if consumeQuota {
- if resp.StatusCode != http.StatusOK {
- return
- }
- err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
- }
+ if resp.StatusCode != http.StatusOK {
+ return
+ }
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
- if consumeQuota {
- responseBody, err := io.ReadAll(resp.Body)
+ responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ }
+ err = json.Unmarshal(responseBody, &textResponse)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
+
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index b2b9926..35353b4 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -9,6 +9,7 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
@@ -20,53 +21,6 @@ import (
"github.com/gin-gonic/gin"
)
-type Midjourney struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
-}
-
-type MidjourneyStatus struct {
- Status int `json:"status"`
-}
-type MidjourneyWithoutStatus struct {
- Id int `json:"id"`
- Code int `json:"code"`
- UserId int `json:"user_id" gorm:"index"`
- Action string `json:"action"`
- MjId string `json:"mj_id" gorm:"index"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"prompt_en"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- ImageUrl string `json:"image_url"`
- Progress string `json:"progress"`
- FailReason string `json:"fail_reason"`
- ChannelId int `json:"channel_id"`
-}
-
-var DefaultModelPrice = map[string]float64{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
-}
-
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
@@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
- var midjRequest Midjourney
+ var midjRequest dto.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
@@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
-func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -167,9 +121,182 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
+ if originTask.Buttons != "" {
+ var buttons []dto.ActionButton
+ err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
+ if err == nil {
+ midjourneyTask.Buttons = buttons
+ }
+ }
+ if originTask.Properties != "" {
+ var properties dto.Properties
+ err := json.Unmarshal([]byte(originTask.Properties), &properties)
+ if err == nil {
+ midjourneyTask.Properties = &properties
+ }
+ }
return
}
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
+ tokenId := c.GetInt("token_id")
+ userId := c.GetInt("id")
+ group := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ var swapFaceRequest dto.SwapFaceRequest
+ err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+ }
+ modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+ modelPrice := common.GetModelPrice(modelName, true)
+ // 如果没有配置价格,则使用默认价格
+ if modelPrice == -1 {
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
+ if !ok {
+ modelPrice = 0.1
+ } else {
+ modelPrice = defaultPrice
+ }
+ }
+ groupRatio := common.GetGroupRatio(group)
+ ratio := modelPrice * groupRatio
+ userQuota, err := model.CacheGetUserQuota(userId)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: err.Error(),
+ }
+ }
+ quota := int(ratio * common.QuotaPerUnit)
+
+ if userQuota-quota < 0 {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+ requestURL := c.Request.URL.String()
+ baseURL := c.GetString("base_url")
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
+ if err != nil {
+ return &mjResp.Response
+ }
+ defer func(ctx context.Context) {
+ if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+ }
+ }(c.Request.Context())
+ midjResponse := &mjResp.Response
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: constant.MjActionSwapFace,
+ MjId: midjResponse.Result,
+ Prompt: "InsightFace",
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: startTime,
+ StartTime: time.Now().UnixNano() / int64(time.Millisecond),
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ ChannelId: c.GetInt("channel_id"),
+ Quota: quota,
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+ }
+ c.Writer.WriteHeader(mjResp.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+}
+
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+ originTask := model.GetByMJId(userId, taskId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ }
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ }
+ c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+ requestURL := c.Request.URL.String()
+ fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
+ if err != nil {
+ return &midjResponseWithStatus.Response
+ }
+ //defer func(ctx context.Context) {
+ // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ // if err != nil {
+ // common.SysError("error consuming token remain quota: " + err.Error())
+ // }
+ // err = model.CacheUpdateUserQuota(userId)
+ // if err != nil {
+ // common.SysError("error update user quota cache: " + err.Error())
+ // }
+ // if quota != 0 {
+ // tokenName := c.GetString("token_name")
+ // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+ // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ // model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ // channelId := c.GetInt("channel_id")
+ // model.UpdateChannelUsedQuota(channelId, quota)
+ // }
+ //}(c.Request.Context())
+ midjResponse := &midjResponseWithStatus.Response
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+}
+
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -184,7 +311,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -203,16 +330,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "do_request_failed",
}
}
- var tasks []Midjourney
+ var tasks []dto.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
- tasks = make([]Midjourney, 0)
+ tasks = make([]dto.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
@@ -235,170 +362,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
-const (
- // type 1 根据 mode 价格不同
- MJSubmitActionImagine = "IMAGINE"
- MJSubmitActionVariation = "VARIATION" //变换
- MJSubmitActionBlend = "BLEND" //混图
-
- MJSubmitActionReroll = "REROLL" //重新生成
- // type 2 固定价格
- MJSubmitActionDescribe = "DESCRIBE"
- MJSubmitActionUpscale = "UPSCALE" // 放大
-)
-
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
- imageModel := "midjourney"
tokenId := c.GetInt("token_id")
- channelType := c.GetInt("channel")
+ //channelType := c.GetInt("channel")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
+ consumeQuota := true
var midjRequest dto.MidjourneyRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "bind_request_body_failed",
- }
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+
+ if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
+ mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
+ if mjErr != nil {
+ return mjErr
}
+ relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "prompt_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
- midjRequest.Action = "IMAGINE"
+ midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
- midjRequest.Action = "DESCRIBE"
+ midjRequest.Action = constant.MjActionDescribe
+ } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
+ midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
- midjRequest.Action = "BLEND"
+ midjRequest.Action = constant.MjActionBlend
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "taskId_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "action_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "index_can_only_be_1_2_3_4",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
- params := convertSimpleChangeParams(midjRequest.Content)
+ params := service.ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_parse_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
}
- mjId = params.ID
+ mjId = params.TaskId
midjRequest.Action = params.Action
+ } else if relayMode == relayconstant.RelayModeMidjourneyModal {
+ //if midjRequest.MaskBase64 == "" {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ //}
+ mjId = midjRequest.TaskId
+ midjRequest.Action = constant.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_no_found",
- }
- } else if originTask.Action == "UPSCALE" {
- //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "upscale_task_can_not_be_change",
- }
- } else if originTask.Status != "SUCCESS" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_status_is_not_success",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
+ } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
- channel, err := model.GetChannelById(originTask.ChannelId, false)
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "channel_not_found",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
- log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
+
+ //if channelType == common.ChannelTypeMidjourneyPlus {
+ // // plus
+ //} else {
+ // // 普通版渠道
+ //
+ //}
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- isModelMapped := false
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_model_mapping_failed",
- }
- }
- if modelMap[imageModel] != "" {
- imageModel = modelMap[imageModel]
- isModelMapped = true
- }
+ if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
+ consumeQuota = false
}
- baseURL := common.ChannelBaseURLs[channelType]
+ //baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
+ baseURL := c.GetString("base_url")
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- log.Printf("fullRequestURL: %s", fullRequestURL)
- var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "marshal_text_request_failed",
- }
- }
- requestBody = bytes.NewBuffer(jsonStr)
- } else {
- requestBody = c.Request.Body
- }
- mjAction := "mj_" + strings.ToLower(midjRequest.Action)
- modelPrice := common.GetModelPrice(mjAction, true)
+ modelName := service.CoverActionToModelName(midjRequest.Action)
+ modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
- defaultPrice, ok := DefaultModelPrice[mjAction]
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -423,53 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "create_request_failed",
- }
+ return &midjResponseWithStatus.Response
}
- //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- //mjToken := ""
- //if c.Request.Header.Get("ApiKey") != "" {
- // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
- //}
- //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
- req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
- // print request header
- log.Printf("request header: %s", req.Header)
- log.Printf("request body: %s", midjRequest.Prompt)
-
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "do_request_failed",
- }
- }
-
- err = req.Body.Close()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_request_body_failed",
- }
- }
- err = c.Request.Body.Close()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_request_body_failed",
- }
- }
- var midjResponse dto.MidjourneyResponse
+ midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
- if consumeQuota {
+ if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
@@ -481,7 +514,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -489,41 +522,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
- //if consumeQuota {
- //
- //}
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "read_response_body_failed",
- }
- }
- err = resp.Body.Close()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_response_body_failed",
- }
- }
-
- err = json.Unmarshal(responseBody, &midjResponse)
- log.Printf("responseBody: %s", string(responseBody))
- log.Printf("midjResponse: %v", midjResponse)
- if resp.StatusCode != 200 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
- }
- }
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_response_body_failed",
- }
- }
-
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -575,8 +573,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//修改返回值
- newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
- responseBody = []byte(newBody)
+ if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
+ newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
+ responseBody = []byte(newBody)
+ }
}
err = midjourneyTask.Insert()
@@ -593,21 +593,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
+ _, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
- err = resp.Body.Close()
+ err = bodyReader.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -622,32 +623,3 @@ type taskChangeParams struct {
Action string
Index int
}
-
-func convertSimpleChangeParams(content string) *taskChangeParams {
- split := strings.Split(content, " ")
- if len(split) != 2 {
- return nil
- }
-
- action := strings.ToLower(split[1])
- changeParams := &taskChangeParams{}
- changeParams.ID = split[0]
-
- if action[0] == 'u' {
- changeParams.Action = "UPSCALE"
- } else if action[0] == 'v' {
- changeParams.Action = "VARIATION"
- } else if action == "r" {
- changeParams.Action = "REROLL"
- return changeParams
- } else {
- return nil
- }
-
- index, err := strconv.Atoi(action[1:2])
- if err != nil || index < 1 || index > 4 {
- return nil
- }
- changeParams.Index = index
- return changeParams
-}
diff --git a/router/relay-router.go b/router/relay-router.go
index 6a30a5a..4addee0 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
+ relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
@@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+ relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
}
//relayMjRouter.Use()
}
diff --git a/service/error.go b/service/error.go
index 303bcf7..424be5d 100644
--- a/service/error.go
+++ b/service/error.go
@@ -11,6 +11,20 @@ import (
"strings"
)
+func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
+ return &dto.MidjourneyResponse{
+ Code: code,
+ Description: desc,
+ }
+}
+
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: *MidjourneyErrorWrapper(code, desc),
+ }
+}
+
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
diff --git a/service/midjourney.go b/service/midjourney.go
new file mode 100644
index 0000000..7c47cd6
--- /dev/null
+++ b/service/midjourney.go
@@ -0,0 +1,224 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "log"
+ "net/http"
+ "one-api/constant"
+ "one-api/dto"
+ relayconstant "one-api/relay/constant"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func CoverActionToModelName(mjAction string) string {
+ modelName := "mj_" + strings.ToLower(mjAction)
+ if mjAction == constant.MjActionSwapFace {
+ modelName = "swap_face"
+ }
+ return modelName
+}
+
+func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
+ action := ""
+ if relayMode == relayconstant.RelayModeMidjourneyAction {
+ // plus request
+ err := CoverPlusActionToNormalAction(midjRequest)
+ if err != nil {
+ return "", err, false
+ }
+ action = midjRequest.Action
+ } else {
+ switch relayMode {
+ case relayconstant.RelayModeMidjourneyImagine:
+ action = constant.MjActionImagine
+ case relayconstant.RelayModeMidjourneyDescribe:
+ action = constant.MjActionDescribe
+ case relayconstant.RelayModeMidjourneyBlend:
+ action = constant.MjActionBlend
+ case relayconstant.RelayModeMidjourneyShorten:
+ action = constant.MjActionShorten
+ case relayconstant.RelayModeMidjourneyChange:
+ action = midjRequest.Action
+ case relayconstant.RelayModeMidjourneyModal:
+ action = constant.MjActionModal
+ case relayconstant.RelayModeSwapFace:
+ action = constant.MjActionSwapFace
+ case relayconstant.RelayModeMidjourneySimpleChange:
+ params := ConvertSimpleChangeParams(midjRequest.Content)
+ if params == nil {
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
+ }
+ action = params.Action
+ case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
+ return "", nil, true
+ default:
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
+ }
+ }
+ modelName := CoverActionToModelName(action)
+ return modelName, nil, true
+}
+
+func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+ // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+ customId := midjRequest.CustomId
+ if customId == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+ }
+ splits := strings.Split(customId, "::")
+ var action string
+ if splits[1] == "JOB" {
+ action = splits[2]
+ } else {
+ action = splits[1]
+ }
+
+ if action == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ if strings.Contains(action, "upsample") {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionUpscale
+ } else if strings.Contains(action, "variation") {
+ midjRequest.Index = 1
+ if action == "variation" {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionVariation
+ } else if action == "low_variation" {
+ midjRequest.Action = constant.MjActionLowVariation
+ } else if action == "high_variation" {
+ midjRequest.Action = constant.MjActionHighVariation
+ }
+ } else if strings.Contains(action, "pan") {
+ midjRequest.Action = constant.MjActionPan
+ midjRequest.Index = 1
+ } else if strings.Contains(action, "reroll") {
+ midjRequest.Action = constant.MjActionReRoll
+ midjRequest.Index = 1
+ } else if action == "Outpaint" {
+ midjRequest.Action = constant.MjActionZoom
+ midjRequest.Index = 1
+ } else if action == "CustomZoom" {
+ midjRequest.Action = constant.MjActionCustomZoom
+ midjRequest.Index = 1
+ } else if action == "Inpaint" {
+ midjRequest.Action = constant.MjActionInPaint
+ midjRequest.Index = 1
+ } else {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
+ }
+ return nil
+}
+
+func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
+ split := strings.Split(content, " ")
+ if len(split) != 2 {
+ return nil
+ }
+
+ action := strings.ToLower(split[1])
+ changeParams := &dto.MidjourneyRequest{}
+ changeParams.TaskId = split[0]
+
+ if action[0] == 'u' {
+ changeParams.Action = "UPSCALE"
+ } else if action[0] == 'v' {
+ changeParams.Action = "VARIATION"
+ } else if action == "r" {
+ changeParams.Action = "REROLL"
+ return changeParams
+ } else {
+ return nil
+ }
+
+ index, err := strconv.Atoi(action[1:2])
+ if err != nil || index < 1 || index > 4 {
+ return nil
+ }
+ changeParams.Index = index
+ return changeParams
+}
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+ var nullBytes []byte
+ //var requestBody io.Reader
+ //requestBody = c.Request.Body
+ // read request body to json, delete accountFilter and notifyHook
+ var mapResult map[string]interface{}
+ err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ delete(mapResult, "accountFilter")
+ delete(mapResult, "notifyHook")
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ // make new request with mapResult
+ reqBody, err := json.Marshal(mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // 使用带有超时的 context 创建新的请求
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+ defer cancel()
+ resp, err := GetHttpClient().Do(req)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ statusCode := resp.StatusCode
+ //if statusCode != 200 {
+ // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+ //}
+ err = req.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ var midjResponse dto.MidjourneyResponse
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
+ }
+
+ err = json.Unmarshal(responseBody, &midjResponse)
+ log.Printf("responseBody: %s", string(responseBody))
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+ }
+ //log.Printf("midjResponse: %v", midjResponse)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: midjResponse,
+ }, responseBody, nil
+}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 1f71208..90c55f1 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -31,10 +31,30 @@ function renderType(type) {
return 放大;
case 'VARIATION':
return 变换;
+ case 'HIGH_VARIATION':
+ return 强变换;
+ case 'LOW_VARIATION':
+ return 弱变换;
+ case 'PAN':
+ return 平移;
case 'DESCRIBE':
return 图生文;
- case 'BLEAND':
+ case 'BLEND':
return 图混合;
+ case 'SHORTEN':
+ return 缩词;
+ case 'REROLL':
+ return 重绘;
+ case 'INPAINT':
+ return 局部重绘-提交;
+ case 'ZOOM':
+ return 变焦;
+ case 'CUSTOM_ZOOM':
+ return 自定义变焦-提交;
+ case 'MODAL':
+ return 窗口处理;
+ case 'SWAP_FACE':
+ return 换脸;
default:
return 未知;
}
@@ -46,9 +66,11 @@ function renderCode(code) {
case 1:
return 已提交;
case 21:
- return 排队中;
+ return 等待中;
case 22:
return 重复提交;
+ case 0:
+ return 未提交;
default:
return 未知;
}
@@ -68,6 +90,8 @@ function renderStatus(type) {
return 执行中;
case 'FAILURE':
return 失败;
+ case 'MODAL':
+ return 窗口等待;
default:
return 未知;
}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index a641a02..bb18d0d 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -1,6 +1,7 @@
export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
+ {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'},
{key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index ddfb744..2b84011 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -95,6 +95,28 @@ const EditChannel = (props) => {
case 26:
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
break;
+ case 2:
+ localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
+ break;
+ case 5:
+ localModels = [
+ 'swap_face',
+ 'mj_imagine',
+ 'mj_variation',
+ 'mj_reroll',
+ 'mj_blend',
+ 'mj_upscale',
+ 'mj_describe',
+ 'mj_zoom',
+ 'mj_shorten',
+ 'mj_modal',
+ 'mj_inpaint',
+ 'mj_custom_zoom',
+ 'mj_high_variation',
+ 'mj_low_variation',
+ 'mj_pan',
+ ];
+ break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
}