diff --git a/Midjourney.md b/Midjourney.md index fe4d433..5733a11 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -4,288 +4,64 @@ ## 模型价格设置(在设置-运营设置-模型固定价格设置中设置) +### 模型列表 + +### midjourney-proxy支持 + +- mj_imagine (绘图) +- mj_variation (变换) +- mj_reroll (重绘) +- mj_blend (混合) +- mj_upscale (放大) +- mj_describe (图生文) + +### 仅midjourney-proxy-plus支持 + +- mj_zoom (比例变焦) +- mj_shorten (提示词缩短) +- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加) +- mj_inpaint (局部重绘提交,必须和mj_modal一同添加) +- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加) +- mj_high_variation (强变换) +- mj_low_variation (弱变换) +- mj_pan (平移) +- swap_face (换脸) + ```json { - "gpt-4-gizmo-*": 0.1, "mj_imagine": 0.1, "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, + "mj_modal": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, + "mj_inpaint": 0, + "mj_custom_zoom": 0, "mj_describe": 0.05, - "mj_upscale": 0.05 + "mj_upscale": 0.05, + "swap_face": 0.05 } ``` ## 渠道设置 -### 对接 midjourney-proxy -1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) -2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney +### 对接 midjourney-proxy(plus) + +1. + +部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) + +2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus** + ,模型选择midjourney,如果有换脸模型,可以选择swap_face 3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080 4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填 ### 对接上游new api -1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney -2. 地址填写上游new api的地址,例如:http://localhost:8080 -3. 密钥填写上游new api的密钥 -## 任务提交 - -### 绘图变化 - -**接口地址**:`/mj/submit/change` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "action" -: - "UPSCALE", - "index" -: - 1, - "notifyHook" -: - "", - "state" -: - "", - "taskId" -: - "1320098173412546" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------| -| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 | -|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | | -|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   state | 自定义参数 | | false | string | | -|   taskId | 任务ID | | true | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -### 提交Imagine任务 - -**接口地址**:`/mj/submit/imagine` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "base64" -: - "", - "notifyHook" -: - "", - "prompt" -: - "Cat", - "state" -: - "" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-------------------------|------|-------|-------------|-------------| -| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 | -|   base64 | 垫图base64 | | false | string | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   prompt | 提示词 | | true | string | | -|   state | 自定义参数 | | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -## 任务查询 - -### 指定ID获取任务 - -**接口地址**:`/mj/task/{id}/fetch` - -**请求方式**:`GET` - -**请求数据类型**:`application/x-www-form-urlencoded` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------|------|------|-------|--------|--------| -| id | 任务ID | path | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 任务 | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|----------------------------------------------------------|----------------|----------------| -| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | | -| description | 任务描述 | string | | -| failReason | 失败原因 | string | | -| finishTime | 结束时间 | integer(int64) | integer(int64) | -| id | 任务ID | string | | -| imageUrl | 图片url | string | | -| progress | 任务进度 | string | | -| prompt | 提示词 | string | | -| promptEn | 提示词-英文 | string | | -| startTime | 开始执行时间 | integer(int64) | integer(int64) | -| state | 自定义参数 | string | | -| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | | -| submitTime | 提交时间 | integer(int64) | integer(int64) | - -**响应示例**: - -```javascript -{ - "action" -: - "", - "description" -: - "", - "failReason" -: - "", - "finishTime" -: - 0, - "id" -: - "", - "imageUrl" -: - "", - "progress" -: - "", - "prompt" -: - "", - "promptEn" -: - "", - "startTime" -: - 0, - "state" -: - "", - "status" -: - "", - "submitTime" -: - 0 -} -``` \ No newline at end of file +1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face +2. 地址填写上游new api的地址,例如:http://localhost:3000 +3. 密钥填写上游new api的密钥 \ No newline at end of file diff --git a/README.md b/README.md index f1d18da..ce3f27f 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ 此分叉版本的主要变更如下: 1. 全新的UI界面(部分界面还待更新) -2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持 +2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持 + [x] /mj/submit/imagine + [x] /mj/submit/change + [x] /mj/submit/blend @@ -26,6 +26,11 @@ + [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**) + [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址) + [x] /task/list-by-condition + + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同) + + [x] /mj/submit/modal + + [x] /mj/submit/shorten + + [x] /mj/task/{id}/image-seed + + [x] /mj/insight-face/swap (InsightFace) 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: + [x] 易支付 4. 支持用key查询使用额度: @@ -49,6 +54,7 @@ 2. 智谱glm-4v,glm-4v识图 3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229) 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 +5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 diff --git a/common/constants.go b/common/constants.go index 3e67c27..9ba7d9e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -193,7 +193,7 @@ const ( ChannelTypeMidjourney = 2 ChannelTypeAzure = 3 ChannelTypeOllama = 4 - ChannelTypeOpenAISB = 5 + ChannelTypeMidjourneyPlus = 5 ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 ChannelTypeCustom = 8 diff --git a/common/model-ratio.go b/common/model-ratio.go index 3836e05..5e6163f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -95,17 +95,31 @@ var ModelRatio = map[string]float64{ "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } -var ModelPrice = map[string]float64{ - "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, +var DefaultModelPrice = map[string]float64{ + "gpt-4-gizmo-*": 0.1, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_modal": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, + "mj_inpaint": 0, + "mj_custom_zoom": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05, } +var ModelPrice = map[string]float64{} + func ModelPrice2JSONString() string { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } jsonBytes, err := json.Marshal(ModelPrice) if err != nil { SysError("error marshalling model price: " + err.Error()) @@ -119,6 +133,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error { } func GetModelPrice(name string, printErr bool) float64 { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } diff --git a/constant/midjourney.go b/constant/midjourney.go new file mode 100644 index 0000000..3d321ca --- /dev/null +++ b/constant/midjourney.go @@ -0,0 +1,42 @@ +package constant + +const ( + MjErrorUnknown = 5 + MjRequestError = 4 +) + +const ( + MjActionImagine = "IMAGINE" + MjActionDescribe = "DESCRIBE" + MjActionBlend = "BLEND" + MjActionUpscale = "UPSCALE" + MjActionVariation = "VARIATION" + MjActionReRoll = "REROLL" + MjActionInPaint = "INPAINT" + MjActionModal = "MODAL" + MjActionZoom = "ZOOM" + MjActionCustomZoom = "CUSTOM_ZOOM" + MjActionShorten = "SHORTEN" + MjActionHighVariation = "HIGH_VARIATION" + MjActionLowVariation = "LOW_VARIATION" + MjActionPan = "PAN" + MjActionSwapFace = "SWAP_FACE" +) + +var MidjourneyModel2Action = map[string]string{ + "mj_imagine": MjActionImagine, + "mj_describe": MjActionDescribe, + "mj_blend": MjActionBlend, + "mj_upscale": MjActionUpscale, + "mj_variation": MjActionVariation, + "mj_reroll": MjActionReRoll, + "mj_modal": MjActionModal, + "mj_inpaint": MjActionInPaint, + "mj_zoom": MjActionZoom, + "mj_custom_zoom": MjActionCustomZoom, + "mj_shorten": MjActionShorten, + "mj_high_variation": MjActionHighVariation, + "mj_low_variation": MjActionLowVariation, + "mj_pan": MjActionPan, + "swap_face": MjActionSwapFace, +} diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 4bcd4d4..96f82ee 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return 0, errors.New("尚未实现") case common.ChannelTypeCustom: baseURL = channel.GetBaseURL() - case common.ChannelTypeOpenAISB: - return updateChannelOpenAISBBalance(channel) + //case common.ChannelTypeOpenAISB: + // return updateChannelOpenAISBBalance(channel) case common.ChannelTypeAIProxy: return updateChannelAIProxyBalance(channel) case common.ChannelTypeAPI2GPT: diff --git a/controller/midjourney.go b/controller/midjourney.go index 1a42270..41db4bf 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -10,145 +10,14 @@ import ( "log" "net/http" "one-api/common" + "one-api/dto" "one-api/model" - relay2 "one-api/relay" "one-api/service" "strconv" "strings" "time" ) -/*func UpdateMidjourneyTask() { - //revocer - //imageModel := "midjourney" - ctx := context.TODO() - imageModel := "midjourney" - defer func() { - if err := recover(); err != nil { - log.Printf("UpdateMidjourneyTask panic: %v", err) - } - }() - for { - time.Sleep(time.Duration(15) * time.Second) - tasks := model.GetAllUnFinishTasks() - if len(tasks) != 0 { - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) - for _, task := range tasks { - common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task)) - midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) - if err != nil { - common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err)) - task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId) - task.Status = "FAILURE" - task.Progress = "100%" - err := task.Update() - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) - continue - } - continue - } - requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) - common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl)) - - req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err)) - continue - } - - // 设置超时时间 - timeout := time.Second * 5 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) - - req.Header.Set("Content-Type", "application/json") - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") - req.Header.Set("mj-api-secret", midjourneyChannel.Key) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) - continue - } - responseBody, err := io.ReadAll(resp.Body) - resp.Body.Close() - log.Printf("responseBody: %s", string(responseBody)) - var responseItem Midjourney - // err = json.NewDecoder(resp.Body).Decode(&responseItem) - err = json.Unmarshal(responseBody, &responseItem) - if err != nil { - if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") { - var responseWithoutStatus MidjourneyWithoutStatus - var responseStatus MidjourneyStatus - err1 := json.Unmarshal(responseBody, &responseWithoutStatus) - err2 := json.Unmarshal(responseBody, &responseStatus) - if err1 == nil && err2 == nil { - jsonData, err3 := json.Marshal(responseWithoutStatus) - if err3 != nil { - log.Printf("UpdateMidjourneyTask error1: %v", err3) - continue - } - err4 := json.Unmarshal(jsonData, &responseStatus) - if err4 != nil { - log.Printf("UpdateMidjourneyTask error2: %v", err4) - continue - } - responseItem.Status = strconv.Itoa(responseStatus.Status) - } else { - log.Printf("UpdateMidjourneyTask error3: %v", err) - continue - } - } else { - log.Printf("UpdateMidjourneyTask error4: %v", err) - continue - } - } - task.Code = 1 - task.Progress = responseItem.Progress - task.PromptEn = responseItem.PromptEn - task.State = responseItem.State - task.SubmitTime = responseItem.SubmitTime - task.StartTime = responseItem.StartTime - task.FinishTime = responseItem.FinishTime - task.ImageUrl = responseItem.ImageUrl - task.Status = responseItem.Status - task.FailReason = responseItem.FailReason - if task.Progress != "100%" && responseItem.FailReason != "" { - common.LogWarn(task.MjId + " 构建失败," + task.FailReason) - task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) - if err != nil { - log.Println("error update user quota cache: " + err.Error()) - } else { - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio("default") - ratio := modelRatio * groupRatio - quota := int(ratio * 1 * 1000) - if quota != 0 { - err := model.IncreaseUserQuota(task.UserId, quota) - if err != nil { - log.Println("fail to increase user quota") - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - - err = task.Update() - if err != nil { - log.Printf("UpdateMidjourneyTask error5: %v", err) - } - log.Printf("UpdateMidjourneyTask success: %v", task) - cancel() - } - } - } -} -*/ - func UpdateMidjourneyTaskBulk() { //imageModel := "midjourney" ctx := context.TODO() @@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() { common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + continue + } responseBody, err := io.ReadAll(resp.Body) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } - var responseItems []relay2.Midjourney + var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) @@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() { for _, responseItem := range responseItems { task := taskM[responseItem.MjId] + + useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime + // 如果时间超过一小时,且进度不是100%,则认为任务失败 + if useTime > 3600000 && task.Progress != "100%" { + responseItem.FailReason = "上游任务超时(超过1小时)" + responseItem.Status = "FAILURE" + } if !checkMjTaskNeedUpdate(task, responseItem) { continue } - task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() { task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason + if responseItem.Properties != nil { + propertiesStr, _ := json.Marshal(responseItem.Properties) + task.Properties = string(propertiesStr) + } + if responseItem.Buttons != nil { + buttonStr, _ := json.Marshal(responseItem.Buttons) + task.Buttons = string(buttonStr) + } + if task.Progress != "100%" && responseItem.FailReason != "" { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" @@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() { } } -func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool { +func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool { if oldTask.Code != 1 { return true } diff --git a/controller/model.go b/controller/model.go index 38c6c46..9a106aa 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" "one-api/relay/channel/moonshot" - "one-api/relay/constant" + relayconstant "one-api/relay/constant" ) // https://platform.openai.com/docs/api-reference/models/list @@ -59,8 +60,8 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < relayconstant.APITypeDummy; i++ { + if i == relayconstant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) @@ -100,6 +101,17 @@ func init() { Parent: nil, }) } + for modelName, _ := range constant.MidjourneyModel2Action { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "midjourney", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/controller/relay.go b/controller/relay.go index 911a7c5..9f866b8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -12,7 +12,6 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strconv" - "strings" ) func Relay(c *gin.Context) { @@ -61,60 +60,35 @@ func Relay(c *gin.Context) { } func RelayMidjourney(c *gin.Context) { - relayMode := relayconstant.RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { - relayMode = relayconstant.RelayModeMidjourneyImagine - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { - relayMode = relayconstant.RelayModeMidjourneyBlend - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") { - relayMode = relayconstant.RelayModeMidjourneyDescribe - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") { - relayMode = relayconstant.RelayModeMidjourneyNotify - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") { - relayMode = relayconstant.RelayModeMidjourneyChange - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") { - relayMode = relayconstant.RelayModeMidjourneyChange - } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") { - relayMode = relayconstant.RelayModeMidjourneyTaskFetch - } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") { - relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition - } - + relayMode := c.GetInt("relay_mode") var err *dto.MidjourneyResponse switch relayMode { case relayconstant.RelayModeMidjourneyNotify: err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) + case relayconstant.RelayModeMidjourneyTaskImageSeed: + err = relay.RelayMidjourneyTaskImageSeed(c) + case relayconstant.RelayModeSwapFace: + err = relay.RelaySwapFace(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } //err = relayMidjourneySubmit(c, relayMode) log.Println(err) if err != nil { - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes - } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" - } - c.JSON(429, gin.H{ - "error": fmt.Sprintf("%s %s", err.Description, err.Result), - "type": "upstream_error", - }) + statusCode := http.StatusBadRequest + if err.Code == 30 { + err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + statusCode = http.StatusTooManyRequests } + c.JSON(statusCode, gin.H{ + "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": "upstream_error", + "code": err.Code, + }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) - //if shouldDisableChannel(&err.Error) { - // channelId := c.GetInt("channel_id") - // channelName := c.GetString("channel_name") - // disableChannel(channelId, channelName, err.Result) - //};'''''''''''''''''''''''''''''''' } } diff --git a/dto/midjourney.go b/dto/midjourney.go index 4c67909..c675f7e 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -1,7 +1,21 @@ package dto +//type SimpleMjRequest struct { +// Prompt string `json:"prompt"` +// CustomId string `json:"customId"` +// Action string `json:"action"` +// Content string `json:"content"` +//} + +type SwapFaceRequest struct { + SourceBase64 string `json:"sourceBase64"` + TargetBase64 string `json:"targetBase64"` +} + type MidjourneyRequest struct { Prompt string `json:"prompt"` + CustomId string `json:"customId"` + BotType string `json:"botType"` NotifyHook string `json:"notifyHook"` Action string `json:"action"` Index int `json:"index"` @@ -9,6 +23,7 @@ type MidjourneyRequest struct { TaskId string `json:"taskId"` Base64Array []string `json:"base64Array"` Content string `json:"content"` + MaskBase64 string `json:"maskBase64"` } type MidjourneyResponse struct { @@ -17,3 +32,64 @@ type MidjourneyResponse struct { Properties interface{} `json:"properties"` Result string `json:"result"` } + +type MidjourneyResponseWithStatusCode struct { + StatusCode int `json:"statusCode"` + Response MidjourneyResponse +} + +type MidjourneyDto struct { + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + BotType string `json:"botType"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` + Properties *Properties `json:"properties"` +} + +type MidjourneyStatus struct { + Status int `json:"status"` +} +type MidjourneyWithoutStatus struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` +} + +type ActionButton struct { + CustomId any `json:"customId"` + Emoji any `json:"emoji"` + Label any `json:"label"` + Type any `json:"type"` + Style any `json:"style"` +} + +type Properties struct { + FinalPrompt string `json:"finalPrompt"` + FinalZhPrompt string `json:"finalZhPrompt"` +} diff --git a/middleware/auth.go b/middleware/auth.go index ef774f6..4b865c2 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) { } token, err := model.ValidateUserToken(key) if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) + abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -125,17 +125,11 @@ func TokenAuth() func(c *gin.Context) { } else { c.Set("token_model_limit_enabled", false) } - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 1ca43dd..ed457a3 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,7 +4,11 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + relayconstant "one-api/relay/constant" + "one-api/service" "strconv" "strings" @@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } if channel.Status != common.ChannelStatusEnabled { - abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") + abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { + shouldSelectChannel := true // Select a channel for the user var modelRequest ModelRequest var err error if strings.HasPrefix(c.Request.URL.Path, "/mj") { - // Midjourney - if modelRequest.Model == "" { - modelRequest.Model = "midjourney" + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel } + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -87,60 +118,61 @@ func Distribute() func(c *gin.Context) { } if tokenModelLimit != nil { if _, ok := tokenModelLimit[modelRequest.Model]; !ok { - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) return } } else { // token model limit is empty, all models are not allowed - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") return } } userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) - - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) - // 如果错误,但是渠道不为空,说明是数据库一致性问题 - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" + if shouldSelectChannel { + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + // 如果错误,但是渠道不为空,说明是数据库一致性问题 + if channel != nil { + common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + message = "数据库一致性已被破坏,请联系管理员" + } + // 如果错误,而且渠道为空,说明是没有可用渠道 + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + return + } + if channel == nil { + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) + return + } + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + c.Set("auto_ban", ban) + c.Set("model_mapping", channel.GetModelMapping()) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 + switch channel.Type { + case common.ChannelTypeAzure: + c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } - // 如果错误,而且渠道为空,说明是没有可用渠道 - abortWithMessage(c, http.StatusServiceUnavailable, message) - return } - if channel == nil { - abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) - return - } - } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - c.Set("auto_ban", ban) - c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - // TODO: api_version统一 - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) } c.Next() } diff --git a/middleware/utils.go b/middleware/utils.go index 021002d..43801c1 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -5,7 +5,7 @@ import ( "one-api/common" ) -func abortWithMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), @@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() common.LogError(c.Request.Context(), message) } + +func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { + c.JSON(statusCode, gin.H{ + "description": description, + "type": "new_api_error", + "code": code, + }) + c.Abort() + common.LogError(c.Request.Context(), description) +} diff --git a/model/ability.go b/model/ability.go index 7a81cc2..b79978d 100644 --- a/model/ability.go +++ b/model/ability.go @@ -147,7 +147,12 @@ func FixAbility() (int, error) { return 0, err } var channels []Channel - err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error + + if len(abilityChannelIds) == 0 { + err = DB.Find(&channels).Error + } else { + err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error + } if err != nil { return 0, err } diff --git a/model/midjourney.go b/model/midjourney.go index 0ef2e55..dd065a3 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -19,6 +19,8 @@ type Midjourney struct { FailReason string `json:"fail_reason"` ChannelId int `json:"channel_id"` Quota int `json:"quota"` + Buttons string `json:"buttons"` + Properties string `json:"properties"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index beea7dc..1790c57 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -17,10 +17,15 @@ const ( RelayModeMidjourneySimpleChange RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition RelayModeAudioSpeech RelayModeAudioTranscription RelayModeAudioTranslation + RelayModeMidjourneyAction + RelayModeMidjourneyModal + RelayModeMidjourneyShorten + RelayModeSwapFace ) func Path2RelayMode(path string) int { @@ -48,3 +53,39 @@ func Path2RelayMode(path string) int { } return relayMode } + +func Path2RelayModeMidjourney(path string) int { + relayMode := RelayModeUnknown + if strings.HasPrefix(path, "/mj/submit/action") { + // midjourney plus + relayMode = RelayModeMidjourneyAction + } else if strings.HasPrefix(path, "/mj/submit/modal") { + // midjourney plus + relayMode = RelayModeMidjourneyModal + } else if strings.HasPrefix(path, "/mj/submit/shorten") { + // midjourney plus + relayMode = RelayModeMidjourneyShorten + } else if strings.HasPrefix(path, "/mj/insight-face/swap") { + // midjourney plus + relayMode = RelayModeSwapFace + } else if strings.HasPrefix(path, "/mj/submit/imagine") { + relayMode = RelayModeMidjourneyImagine + } else if strings.HasPrefix(path, "/mj/submit/blend") { + relayMode = RelayModeMidjourneyBlend + } else if strings.HasPrefix(path, "/mj/submit/describe") { + relayMode = RelayModeMidjourneyDescribe + } else if strings.HasPrefix(path, "/mj/notify") { + relayMode = RelayModeMidjourneyNotify + } else if strings.HasPrefix(path, "/mj/submit/change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasPrefix(path, "/mj/submit/simple-change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasSuffix(path, "/fetch") { + relayMode = RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(path, "/image-seed") { + relayMode = RelayModeMidjourneyTaskImageSeed + } else if strings.HasSuffix(path, "/list-by-condition") { + relayMode = RelayModeMidjourneyTaskFetchByCondition + } + return relayMode +} diff --git a/relay/relay-image.go b/relay/relay-image.go index 3065496..aabe4ba 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") startTime := time.Now() var imageRequest dto.ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if imageRequest.Model == "" { @@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC var textResponse dto.ImageResponse defer func(ctx context.Context) { useTimeSeconds := time.Now().Unix() - startTime.Unix() - if consumeQuota { - if resp.StatusCode != http.StatusOK { - return - } - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + if resp.StatusCode != http.StatusOK { + return + } + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index b2b9926..35353b4 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relayconstant "one-api/relay/constant" @@ -20,53 +21,6 @@ import ( "github.com/gin-gonic/gin" ) -type Midjourney struct { - MjId string `json:"id"` - Action string `json:"action"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - ImageUrl string `json:"imageUrl"` - Status string `json:"status"` - Progress string `json:"progress"` - FailReason string `json:"failReason"` -} - -type MidjourneyStatus struct { - Status int `json:"status"` -} -type MidjourneyWithoutStatus struct { - Id int `json:"id"` - Code int `json:"code"` - UserId int `json:"user_id" gorm:"index"` - Action string `json:"action"` - MjId string `json:"mj_id" gorm:"index"` - Prompt string `json:"prompt"` - PromptEn string `json:"prompt_en"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - ImageUrl string `json:"image_url"` - Progress string `json:"progress"` - FailReason string `json:"fail_reason"` - ChannelId int `json:"channel_id"` -} - -var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, -} - func RelayMidjourneyImage(c *gin.Context) { taskId := c.Param("id") midjourneyTask := model.GetByOnlyMJId(taskId) @@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) { } func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { - var midjRequest Midjourney + var midjRequest dto.MidjourneyDto err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { return &dto.MidjourneyResponse{ @@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { return nil } -func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) { +func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -167,9 +121,182 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Action = originTask.Action midjourneyTask.Description = originTask.Description midjourneyTask.Prompt = originTask.Prompt + if originTask.Buttons != "" { + var buttons []dto.ActionButton + err := json.Unmarshal([]byte(originTask.Buttons), &buttons) + if err == nil { + midjourneyTask.Buttons = buttons + } + } + if originTask.Properties != "" { + var properties dto.Properties + err := json.Unmarshal([]byte(originTask.Properties), &properties) + if err == nil { + midjourneyTask.Properties = &properties + } + } return } +func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { + startTime := time.Now().UnixNano() / int64(time.Millisecond) + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + channelId := c.GetInt("channel_id") + var swapFaceRequest dto.SwapFaceRequest + err := common.UnmarshalBodyReusable(c, &swapFaceRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") + } + modelName := service.CoverActionToModelName(constant.MjActionSwapFace) + modelPrice := common.GetModelPrice(modelName, true) + // 如果没有配置价格,则使用默认价格 + if modelPrice == -1 { + defaultPrice, ok := common.DefaultModelPrice[modelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + groupRatio := common.GetGroupRatio(group) + ratio := modelPrice * groupRatio + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return &dto.MidjourneyResponse{ + Code: 4, + Description: err.Error(), + } + } + quota := int(ratio * common.QuotaPerUnit) + + if userQuota-quota < 0 { + return &dto.MidjourneyResponse{ + Code: 4, + Description: "quota_not_enough", + } + } + requestURL := c.Request.URL.String() + baseURL := c.GetString("base_url") + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) + if err != nil { + return &mjResp.Response + } + defer func(ctx context.Context) { + if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }(c.Request.Context()) + midjResponse := &mjResp.Response + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: constant.MjActionSwapFace, + MjId: midjResponse.Result, + Prompt: "InsightFace", + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: startTime, + StartTime: time.Now().UnixNano() / int64(time.Millisecond), + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + ChannelId: c.GetInt("channel_id"), + Quota: quota, + } + err = midjourneyTask.Insert() + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed") + } + c.Writer.WriteHeader(mjResp.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil +} + +func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { + taskId := c.Param("id") + userId := c.GetInt("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + } + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + } + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + requestURL := c.Request.URL.String() + fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) + midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) + if err != nil { + return &midjResponseWithStatus.Response + } + //defer func(ctx context.Context) { + // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + // if err != nil { + // common.SysError("error consuming token remain quota: " + err.Error()) + // } + // err = model.CacheUpdateUserQuota(userId) + // if err != nil { + // common.SysError("error update user quota cache: " + err.Error()) + // } + // if quota != 0 { + // tokenName := c.GetString("token_name") + // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) + // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + // model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + // channelId := c.GetInt("channel_id") + // model.UpdateChannelUsedQuota(channelId, quota) + // } + //}(c.Request.Context()) + midjResponse := &midjResponseWithStatus.Response + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil +} + func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -184,7 +311,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "task_no_found", } } - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { return &dto.MidjourneyResponse{ @@ -203,16 +330,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "do_request_failed", } } - var tasks []Midjourney + var tasks []dto.MidjourneyDto if len(condition.IDs) != 0 { originTasks := model.GetByMJIds(userId, condition.IDs) for _, originTask := range originTasks { - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) tasks = append(tasks, midjourneyTask) } } if tasks == nil { - tasks = make([]Midjourney, 0) + tasks = make([]dto.MidjourneyDto, 0) } respBody, err = json.Marshal(tasks) if err != nil { @@ -235,170 +362,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -const ( - // type 1 根据 mode 价格不同 - MJSubmitActionImagine = "IMAGINE" - MJSubmitActionVariation = "VARIATION" //变换 - MJSubmitActionBlend = "BLEND" //混图 - - MJSubmitActionReroll = "REROLL" //重新生成 - // type 2 固定价格 - MJSubmitActionDescribe = "DESCRIBE" - MJSubmitActionUpscale = "UPSCALE" // 放大 -) - func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - imageModel := "midjourney" tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") + //channelType := c.GetInt("channel") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") channelId := c.GetInt("channel_id") + consumeQuota := true var midjRequest dto.MidjourneyRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "bind_request_body_failed", - } + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + + if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + mjErr := service.CoverPlusActionToNormalAction(&midjRequest) + if mjErr != nil { + return mjErr } + relayMode = relayconstant.RelayModeMidjourneyChange } if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "prompt_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } - midjRequest.Action = "IMAGINE" + midjRequest.Action = constant.MjActionImagine } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 - midjRequest.Action = "DESCRIBE" + midjRequest.Action = constant.MjActionDescribe + } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 - midjRequest.Action = "BLEND" + midjRequest.Action = constant.MjActionBlend } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "taskId_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "action_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") } else if midjRequest.Index == 0 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "index_can_only_be_1_2_3_4", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required") } //action = midjRequest.Action mjId = midjRequest.TaskId } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } - params := convertSimpleChangeParams(midjRequest.Content) + params := service.ConvertSimpleChangeParams(midjRequest.Content) if params == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_parse_failed", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") } - mjId = params.ID + mjId = params.TaskId midjRequest.Action = params.Action + } else if relayMode == relayconstant.RelayModeMidjourneyModal { + //if midjRequest.MaskBase64 == "" { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") + //} + mjId = midjRequest.TaskId + midjRequest.Action = constant.MjActionModal } originTask := model.GetByMJId(userId, mjId) if originTask == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_no_found", - } - } else if originTask.Action == "UPSCALE" { - //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). - return &dto.MidjourneyResponse{ - Code: 4, - Description: "upscale_task_can_not_be_change", - } - } else if originTask.Status != "SUCCESS" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_status_is_not_success", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") + } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 - channel, err := model.GetChannelById(originTask.ChannelId, false) + channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "channel_not_found", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) - log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt + + //if channelType == common.ChannelTypeMidjourneyPlus { + // // plus + //} else { + // // 普通版渠道 + // + //} } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_model_mapping_failed", - } - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } + if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom { + consumeQuota = false } - baseURL := common.ChannelBaseURLs[channelType] + //baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } + baseURL := c.GetString("base_url") //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - log.Printf("fullRequestURL: %s", fullRequestURL) - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "marshal_text_request_failed", - } - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - mjAction := "mj_" + strings.ToLower(midjRequest.Action) - modelPrice := common.GetModelPrice(mjAction, true) + modelName := service.CoverActionToModelName(midjRequest.Action) + modelPrice := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if modelPrice == -1 { - defaultPrice, ok := DefaultModelPrice[mjAction] + defaultPrice, ok := common.DefaultModelPrice[modelName] if !ok { modelPrice = 0.1 } else { @@ -423,53 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "create_request_failed", - } + return &midjResponseWithStatus.Response } - //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - //mjToken := "" - //if c.Request.Header.Get("ApiKey") != "" { - // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1] - //} - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") - req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) - // print request header - log.Printf("request header: %s", req.Header) - log.Printf("request body: %s", midjRequest.Prompt) - - resp, err := service.GetHttpClient().Do(req) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "do_request_failed", - } - } - - err = req.Body.Close() - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "close_request_body_failed", - } - } - err = c.Request.Body.Close() - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "close_request_body_failed", - } - } - var midjResponse dto.MidjourneyResponse + midjResponse := &midjResponseWithStatus.Response defer func(ctx context.Context) { - if consumeQuota { + if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) @@ -481,7 +514,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) @@ -489,41 +522,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } }(c.Request.Context()) - //if consumeQuota { - // - //} - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "read_response_body_failed", - } - } - err = resp.Body.Close() - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "close_response_body_failed", - } - } - - err = json.Unmarshal(responseBody, &midjResponse) - log.Printf("responseBody: %s", string(responseBody)) - log.Printf("midjResponse: %v", midjResponse) - if resp.StatusCode != 200 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode), - } - } - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_response_body_failed", - } - } - // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md //1-提交成功 // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}} @@ -575,8 +573,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //修改返回值 - newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) - responseBody = []byte(newBody) + if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom { + newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) + responseBody = []byte(newBody) + } } err = midjourneyTask.Insert() @@ -593,21 +593,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody = []byte(newBody) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, bodyReader) if err != nil { return &dto.MidjourneyResponse{ Code: 4, Description: "copy_response_body_failed", } } - err = resp.Body.Close() + err = bodyReader.Close() if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -622,32 +623,3 @@ type taskChangeParams struct { Action string Index int } - -func convertSimpleChangeParams(content string) *taskChangeParams { - split := strings.Split(content, " ") - if len(split) != 2 { - return nil - } - - action := strings.ToLower(split[1]) - changeParams := &taskChangeParams{} - changeParams.ID = split[0] - - if action[0] == 'u' { - changeParams.Action = "UPSCALE" - } else if action[0] == 'v' { - changeParams.Action = "VARIATION" - } else if action == "r" { - changeParams.Action = "REROLL" - return changeParams - } else { - return nil - } - - index, err := strconv.Atoi(action[1:2]) - if err != nil || index < 1 || index > 4 { - return nil - } - changeParams.Index = index - return changeParams -} diff --git a/router/relay-router.go b/router/relay-router.go index 6a30a5a..4addee0 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { + relayMjRouter.POST("/submit/action", controller.RelayMidjourney) + relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney) + relayMjRouter.POST("/submit/modal", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) @@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) relayMjRouter.POST("/notify", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) + relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) + relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) } //relayMjRouter.Use() } diff --git a/service/error.go b/service/error.go index 303bcf7..424be5d 100644 --- a/service/error.go +++ b/service/error.go @@ -11,6 +11,20 @@ import ( "strings" ) +func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { + return &dto.MidjourneyResponse{ + Code: code, + Description: desc, + } +} + +func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode { + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: *MidjourneyErrorWrapper(code, desc), + } +} + // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() diff --git a/service/midjourney.go b/service/midjourney.go new file mode 100644 index 0000000..7c47cd6 --- /dev/null +++ b/service/midjourney.go @@ -0,0 +1,224 @@ +package service + +import ( + "context" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "log" + "net/http" + "one-api/constant" + "one-api/dto" + relayconstant "one-api/relay/constant" + "strconv" + "strings" + "time" +) + +func CoverActionToModelName(mjAction string) string { + modelName := "mj_" + strings.ToLower(mjAction) + if mjAction == constant.MjActionSwapFace { + modelName = "swap_face" + } + return modelName +} + +func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) { + action := "" + if relayMode == relayconstant.RelayModeMidjourneyAction { + // plus request + err := CoverPlusActionToNormalAction(midjRequest) + if err != nil { + return "", err, false + } + action = midjRequest.Action + } else { + switch relayMode { + case relayconstant.RelayModeMidjourneyImagine: + action = constant.MjActionImagine + case relayconstant.RelayModeMidjourneyDescribe: + action = constant.MjActionDescribe + case relayconstant.RelayModeMidjourneyBlend: + action = constant.MjActionBlend + case relayconstant.RelayModeMidjourneyShorten: + action = constant.MjActionShorten + case relayconstant.RelayModeMidjourneyChange: + action = midjRequest.Action + case relayconstant.RelayModeMidjourneyModal: + action = constant.MjActionModal + case relayconstant.RelayModeSwapFace: + action = constant.MjActionSwapFace + case relayconstant.RelayModeMidjourneySimpleChange: + params := ConvertSimpleChangeParams(midjRequest.Content) + if params == nil { + return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false + } + action = params.Action + case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: + return "", nil, true + default: + return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false + } + } + modelName := CoverActionToModelName(action) + return modelName, nil, true +} + +func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Index = 1 + if action == "variation" { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionVariation + } else if action == "low_variation" { + midjRequest.Action = constant.MjActionLowVariation + } else if action == "high_variation" { + midjRequest.Action = constant.MjActionHighVariation + } + } else if strings.Contains(action, "pan") { + midjRequest.Action = constant.MjActionPan + midjRequest.Index = 1 + } else if strings.Contains(action, "reroll") { + midjRequest.Action = constant.MjActionReRoll + midjRequest.Index = 1 + } else if action == "Outpaint" { + midjRequest.Action = constant.MjActionZoom + midjRequest.Index = 1 + } else if action == "CustomZoom" { + midjRequest.Action = constant.MjActionCustomZoom + midjRequest.Index = 1 + } else if action == "Inpaint" { + midjRequest.Action = constant.MjActionInPaint + midjRequest.Index = 1 + } else { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) + } + return nil +} + +func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { + split := strings.Split(content, " ") + if len(split) != 2 { + return nil + } + + action := strings.ToLower(split[1]) + changeParams := &dto.MidjourneyRequest{} + changeParams.TaskId = split[0] + + if action[0] == 'u' { + changeParams.Action = "UPSCALE" + } else if action[0] == 'v' { + changeParams.Action = "VARIATION" + } else if action == "r" { + changeParams.Action = "REROLL" + return changeParams + } else { + return nil + } + + index, err := strconv.Atoi(action[1:2]) + if err != nil || index < 1 || index > 4 { + return nil + } + changeParams.Index = index + return changeParams +} + +func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { + var nullBytes []byte + //var requestBody io.Reader + //requestBody = c.Request.Body + // read request body to json, delete accountFilter and notifyHook + var mapResult map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + delete(mapResult, "accountFilter") + delete(mapResult, "notifyHook") + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + // make new request with mapResult + reqBody, err := json.Marshal(mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody))) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) + defer cancel() + resp, err := GetHttpClient().Do(req) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err + } + statusCode := resp.StatusCode + //if statusCode != 200 { + // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil + //} + err = req.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + err = c.Request.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + var midjResponse dto.MidjourneyResponse + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err + } + err = resp.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err + } + + err = json.Unmarshal(responseBody, &midjResponse) + log.Printf("responseBody: %s", string(responseBody)) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } + //log.Printf("midjResponse: %v", midjResponse) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: midjResponse, + }, responseBody, nil +} diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 1f71208..90c55f1 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -31,10 +31,30 @@ function renderType(type) { return 放大; case 'VARIATION': return 变换; + case 'HIGH_VARIATION': + return 强变换; + case 'LOW_VARIATION': + return 弱变换; + case 'PAN': + return 平移; case 'DESCRIBE': return 图生文; - case 'BLEAND': + case 'BLEND': return 图混合; + case 'SHORTEN': + return 缩词; + case 'REROLL': + return 重绘; + case 'INPAINT': + return 局部重绘-提交; + case 'ZOOM': + return 变焦; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; + case 'SWAP_FACE': + return 换脸; default: return 未知; } @@ -46,9 +66,11 @@ function renderCode(code) { case 1: return 已提交; case 21: - return 排队中; + return 等待中; case 22: return 重复提交; + case 0: + return 未提交; default: return 未知; } @@ -68,6 +90,8 @@ function renderStatus(type) { return 执行中; case 'FAILURE': return 失败; + case 'MODAL': + return 窗口等待; default: return 未知; } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a641a02..bb18d0d 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,6 +1,7 @@ export const CHANNEL_OPTIONS = [ {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'}, {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'}, + {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'}, {key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'}, {key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'}, {key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'}, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index ddfb744..2b84011 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -95,6 +95,28 @@ const EditChannel = (props) => { case 26: localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; break; + case 2: + localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe']; + break; + case 5: + localModels = [ + 'swap_face', + 'mj_imagine', + 'mj_variation', + 'mj_reroll', + 'mj_blend', + 'mj_upscale', + 'mj_describe', + 'mj_zoom', + 'mj_shorten', + 'mj_modal', + 'mj_inpaint', + 'mj_custom_zoom', + 'mj_high_variation', + 'mj_low_variation', + 'mj_pan', + ]; + break; } setInputs((inputs) => ({...inputs, models: localModels})); }