From 37c0c8ebdd252e5c9dcd205200e416e758dd8dc2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 15:37:01 +0800 Subject: [PATCH 01/16] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?= =?UTF-8?q?=E5=AE=B9midjourney-proxy-plus?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 +- constant/midjourney.go | 16 ++ controller/channel-billing.go | 4 +- controller/midjourney.go | 18 +- controller/relay.go | 37 ++-- dto/midjourney.go | 52 ++++++ middleware/auth.go | 6 - model/midjourney.go | 1 + relay/constant/relay_mode.go | 2 + relay/relay-mj.go | 245 +++++++++++++------------ router/relay-router.go | 2 + service/error.go | 7 + web/src/components/MjLogsTable.js | 6 + web/src/constants/channel.constants.js | 1 + 14 files changed, 246 insertions(+), 153 deletions(-) create mode 100644 constant/midjourney.go diff --git a/common/constants.go b/common/constants.go index 98fa67a..cbb7861 100644 --- a/common/constants.go +++ b/common/constants.go @@ -189,7 +189,7 @@ const ( ChannelTypeMidjourney = 2 ChannelTypeAzure = 3 ChannelTypeOllama = 4 - ChannelTypeOpenAISB = 5 + ChannelTypeMidjourneyPlus = 5 ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 ChannelTypeCustom = 8 diff --git a/constant/midjourney.go b/constant/midjourney.go new file mode 100644 index 0000000..dbcc5c8 --- /dev/null +++ b/constant/midjourney.go @@ -0,0 +1,16 @@ +package constant + +const ( + MjErrorUnknown = 5 + MjRequestError = 4 +) + +const ( + MjActionImagine = "IMAGINE" + MjActionDescribe = "DESCRIBE" + MjActionBlend = "BLEND" + MjActionUpscale = "UPSCALE" + MjActionVariation = "VARIATION" + MjActionInPaint = "INPAINT" + MjActionInPaintPre = "INPAINT_PRE" +) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 4bcd4d4..96f82ee 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return 0, errors.New("尚未实现") case common.ChannelTypeCustom: baseURL = channel.GetBaseURL() - case common.ChannelTypeOpenAISB: - return updateChannelOpenAISBBalance(channel) + //case common.ChannelTypeOpenAISB: + // return updateChannelOpenAISBBalance(channel) case common.ChannelTypeAIProxy: return updateChannelAIProxyBalance(channel) case common.ChannelTypeAPI2GPT: diff --git a/controller/midjourney.go b/controller/midjourney.go index 1a42270..cac253c 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -10,8 +10,8 @@ import ( "log" "net/http" "one-api/common" + "one-api/dto" "one-api/model" - relay2 "one-api/relay" "one-api/service" "strconv" "strings" @@ -75,11 +75,11 @@ import ( responseBody, err := io.ReadAll(resp.Body) resp.Body.Close() log.Printf("responseBody: %s", string(responseBody)) - var responseItem Midjourney + var responseItem MidjourneyDto // err = json.NewDecoder(resp.Body).Decode(&responseItem) err = json.Unmarshal(responseBody, &responseItem) if err != nil { - if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") { + if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") { var responseWithoutStatus MidjourneyWithoutStatus var responseStatus MidjourneyStatus err1 := json.Unmarshal(responseBody, &responseWithoutStatus) @@ -228,12 +228,16 @@ func UpdateMidjourneyTaskBulk() { common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + continue + } responseBody, err := io.ReadAll(resp.Body) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } - var responseItems []relay2.Midjourney + var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) @@ -259,6 +263,10 @@ func UpdateMidjourneyTaskBulk() { task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason + if responseItem.Buttons != nil { + buttonStr, _ := json.Marshal(responseItem.Buttons) + task.Buttons = string(buttonStr) + } if task.Progress != "100%" && responseItem.FailReason != "" { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" @@ -286,7 +294,7 @@ func UpdateMidjourneyTaskBulk() { } } -func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool { +func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool { if oldTask.Code != 1 { return true } diff --git a/controller/relay.go b/controller/relay.go index 911a7c5..a42db2e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,13 @@ func Relay(c *gin.Context) { func RelayMidjourney(c *gin.Context) { relayMode := relayconstant.RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { + if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyAction + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyModal + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { relayMode = relayconstant.RelayModeMidjourneyImagine } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { relayMode = relayconstant.RelayModeMidjourneyBlend @@ -86,35 +92,24 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) + //case relayconstant.RelayModeMidjourneyModal: + // err = relay.RelayMidjournneyModal(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } //err = relayMidjourneySubmit(c, relayMode) log.Println(err) if err != nil { - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes - } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" - } - c.JSON(429, gin.H{ - "error": fmt.Sprintf("%s %s", err.Description, err.Result), - "type": "upstream_error", - }) + if err.Code == 30 { + err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" } + c.JSON(429, gin.H{ + "error": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": "upstream_error", + "code": err.Code, + }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) - //if shouldDisableChannel(&err.Error) { - // channelId := c.GetInt("channel_id") - // channelName := c.GetString("channel_name") - // disableChannel(channelId, channelName, err.Result) - //};'''''''''''''''''''''''''''''''' } } diff --git a/dto/midjourney.go b/dto/midjourney.go index 4c67909..a16a65e 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -2,6 +2,8 @@ package dto type MidjourneyRequest struct { Prompt string `json:"prompt"` + CustomId string `json:"customId"` + BotType string `json:"botType"` NotifyHook string `json:"notifyHook"` Action string `json:"action"` Index int `json:"index"` @@ -9,6 +11,7 @@ type MidjourneyRequest struct { TaskId string `json:"taskId"` Base64Array []string `json:"base64Array"` Content string `json:"content"` + MaskBase64 string `json:"maskBase64"` } type MidjourneyResponse struct { @@ -17,3 +20,52 @@ type MidjourneyResponse struct { Properties interface{} `json:"properties"` Result string `json:"result"` } + +type MidjourneyDto struct { + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` +} + +type MidjourneyStatus struct { + Status int `json:"status"` +} +type MidjourneyWithoutStatus struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` +} + +type ActionButton struct { + CustomId any `json:"customId"` + Emoji any `json:"emoji"` + Label any `json:"label"` + Type any `json:"type"` + Style any `json:"style"` +} diff --git a/middleware/auth.go b/middleware/auth.go index ef774f6..a8dac30 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -125,12 +125,6 @@ func TokenAuth() func(c *gin.Context) { } else { c.Set("token_model_limit_enabled", false) } - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/model/midjourney.go b/model/midjourney.go index 0ef2e55..f20ab32 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -19,6 +19,7 @@ type Midjourney struct { FailReason string `json:"fail_reason"` ChannelId int `json:"channel_id"` Quota int `json:"quota"` + Buttons string `json:"buttons"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index beea7dc..c49caae 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -21,6 +21,8 @@ const ( RelayModeAudioSpeech RelayModeAudioTranscription RelayModeAudioTranslation + RelayModeMidjourneyAction + RelayModeMidjourneyModal ) func Path2RelayMode(path string) int { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index b2b9926..f667cd1 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relayconstant "one-api/relay/constant" @@ -20,51 +21,15 @@ import ( "github.com/gin-gonic/gin" ) -type Midjourney struct { - MjId string `json:"id"` - Action string `json:"action"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - ImageUrl string `json:"imageUrl"` - Status string `json:"status"` - Progress string `json:"progress"` - FailReason string `json:"failReason"` -} - -type MidjourneyStatus struct { - Status int `json:"status"` -} -type MidjourneyWithoutStatus struct { - Id int `json:"id"` - Code int `json:"code"` - UserId int `json:"user_id" gorm:"index"` - Action string `json:"action"` - MjId string `json:"mj_id" gorm:"index"` - Prompt string `json:"prompt"` - PromptEn string `json:"prompt_en"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - ImageUrl string `json:"image_url"` - Progress string `json:"progress"` - FailReason string `json:"fail_reason"` - ChannelId int `json:"channel_id"` -} - var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, } func RelayMidjourneyImage(c *gin.Context) { @@ -108,7 +73,7 @@ func RelayMidjourneyImage(c *gin.Context) { } func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { - var midjRequest Midjourney + var midjRequest dto.MidjourneyDto err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { return &dto.MidjourneyResponse{ @@ -147,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { return nil } -func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) { +func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -167,9 +132,41 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Action = originTask.Action midjourneyTask.Description = originTask.Description midjourneyTask.Prompt = originTask.Prompt + if originTask.Buttons != "" { + var buttons []dto.ActionButton + err := json.Unmarshal([]byte(originTask.Buttons), &buttons) + if err == nil { + midjourneyTask.Buttons = buttons + } + } return } +func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse { + userId := c.GetInt("id") + var midjRequest dto.MidjourneyRequest + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + originTask := model.GetByMJId(userId, midjRequest.TaskId) + if originTask == nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + } + + respBody, err := json.Marshal(midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + c.Writer.Header().Set("Content-Type", "application/json") + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil + +} + func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -184,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "task_no_found", } } - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := getMidjourneyTaskDto(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { return &dto.MidjourneyResponse{ @@ -203,16 +200,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "do_request_failed", } } - var tasks []Midjourney + var tasks []dto.MidjourneyDto if len(condition.IDs) != 0 { originTasks := model.GetByMJIds(userId, condition.IDs) for _, originTask := range originTasks { - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := getMidjourneyTaskDto(c, originTask) tasks = append(tasks, midjourneyTask) } } if tasks == nil { - tasks = make([]Midjourney, 0) + tasks = make([]dto.MidjourneyDto, 0) } respBody, err = json.Marshal(tasks) if err != nil { @@ -235,44 +232,32 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -const ( - // type 1 根据 mode 价格不同 - MJSubmitActionImagine = "IMAGINE" - MJSubmitActionVariation = "VARIATION" //变换 - MJSubmitActionBlend = "BLEND" //混图 - - MJSubmitActionReroll = "REROLL" //重新生成 - // type 2 固定价格 - MJSubmitActionDescribe = "DESCRIBE" - MJSubmitActionUpscale = "UPSCALE" // 放大 -) - func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { imageModel := "midjourney" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") channelId := c.GetInt("channel_id") + consumeQuota := true var midjRequest dto.MidjourneyRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "bind_request_body_failed", - } + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + + if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + mjErr := coverPlusActionToNormalAction(&midjRequest) + if mjErr != nil { + return mjErr } + relayMode = relayconstant.RelayModeMidjourneyChange } if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "prompt_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = "IMAGINE" } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 @@ -283,71 +268,58 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "taskId_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "action_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") } else if midjRequest.Index == 0 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "index_can_only_be_1_2_3_4", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required") } //action = midjRequest.Action mjId = midjRequest.TaskId } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } params := convertSimpleChangeParams(midjRequest.Content) if params == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_parse_failed", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") } mjId = params.ID midjRequest.Action = params.Action + } else if relayMode == relayconstant.RelayModeMidjourneyModal { + if midjRequest.MaskBase64 == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") + } + mjId = midjRequest.TaskId + midjRequest.Action = "INPAINT" } originTask := model.GetByMJId(userId, mjId) if originTask == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_no_found", - } - } else if originTask.Action == "UPSCALE" { - //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). - return &dto.MidjourneyResponse{ - Code: 4, - Description: "upscale_task_can_not_be_change", - } - } else if originTask.Status != "SUCCESS" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_status_is_not_success", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") + } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 channel, err := model.GetChannelById(originTask.ChannelId, false) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "channel_not_found", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) - log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) + log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt + + if channelType == common.ChannelTypeMidjourneyPlus { + // plus + } else { + // 普通版渠道 + + } + } + + if midjRequest.Action == constant.MjActionInPaintPre { + consumeQuota = false } // map model name @@ -379,7 +351,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - log.Printf("fullRequestURL: %s", fullRequestURL) var requestBody io.Reader if isModelMapped { @@ -394,6 +365,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } else { requestBody = c.Request.Body } + mjAction := "mj_" + strings.ToLower(midjRequest.Action) modelPrice := common.GetModelPrice(mjAction, true) // 如果没有配置价格,则使用默认价格 @@ -489,9 +461,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } }(c.Request.Context()) - //if consumeQuota { - // - //} responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -651,3 +620,43 @@ func convertSimpleChangeParams(content string) *taskChangeParams { changeParams.Index = index return changeParams } + +func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Action = constant.MjActionVariation + } else if strings.Contains(action, "pan") { + midjRequest.Action = constant.MjActionVariation + midjRequest.Index = 1 + } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") { + midjRequest.Action = constant.MjActionInPaintPre + } else if action == "Inpaint" { + midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Index = 1 + } else { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + return nil +} diff --git a/router/relay-router.go b/router/relay-router.go index 6a30a5a..68b762b 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -47,6 +47,8 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { + relayMjRouter.POST("/submit/action", controller.RelayMidjourney) + relayMjRouter.POST("/submit/modal", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) diff --git a/service/error.go b/service/error.go index 303bcf7..91c78c8 100644 --- a/service/error.go +++ b/service/error.go @@ -11,6 +11,13 @@ import ( "strings" ) +func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { + return &dto.MidjourneyResponse{ + Code: code, + Description: desc, + } +} + // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 1f71208..4f17c14 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,10 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'INPAINT': + return 局部重绘; + case 'INPAINT_PRE': + return 局部重绘-预处理; default: return 未知; } @@ -68,6 +72,8 @@ function renderStatus(type) { return 执行中; case 'FAILURE': return 失败; + case 'MODAL': + return 窗口等待; default: return 未知; } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a641a02..bb18d0d 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,6 +1,7 @@ export const CHANNEL_OPTIONS = [ {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'}, {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'}, + {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'}, {key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'}, {key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'}, {key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'}, From fd3a41bacb326f7d1b15b9b447efafced679d419 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 16:19:22 +0800 Subject: [PATCH 02/16] =?UTF-8?q?feat:=20=E8=AF=B7=E6=B1=82=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/midjourney.go | 1 + relay/relay-mj.go | 27 +++++++++++++-------------- web/src/components/MjLogsTable.js | 2 ++ web/src/pages/Channel/EditChannel.js | 6 ++++++ 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/dto/midjourney.go b/dto/midjourney.go index a16a65e..4fef4e1 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -25,6 +25,7 @@ type MidjourneyDto struct { MjId string `json:"id"` Action string `json:"action"` CustomId string `json:"customId"` + BotType string `json:"botType"` Prompt string `json:"prompt"` PromptEn string `json:"promptEn"` Description string `json:"description"` diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f667cd1..5fafc89 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -112,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { return nil } -func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { +func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -181,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "task_no_found", } } - midjourneyTask := getMidjourneyTaskDto(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { return &dto.MidjourneyResponse{ @@ -204,7 +204,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse if len(condition.IDs) != 0 { originTasks := model.GetByMJIds(userId, condition.IDs) for _, originTask := range originTasks { - midjourneyTask := getMidjourneyTaskDto(c, originTask) + midjourneyTask := coverMidjourneyTaskDto(c, originTask) tasks = append(tasks, midjourneyTask) } } @@ -403,23 +403,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) - + timeout := time.Second * 30 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - //mjToken := "" - //if c.Request.Header.Get("ApiKey") != "" { - // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1] - //} - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) // print request header - log.Printf("request header: %s", req.Header) - log.Printf("request body: %s", midjRequest.Prompt) + //log.Printf("request header: %s", req.Header) + //log.Printf("request body: %s", midjRequest.Prompt) + defer cancel() resp, err := service.GetHttpClient().Do(req) if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "do_request_failed", } } @@ -427,14 +426,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons err = req.Body.Close() if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "close_request_body_failed", } } err = c.Request.Body.Close() if err != nil { return &dto.MidjourneyResponse{ - Code: 4, + Code: 5, Description: "close_request_body_failed", } } diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 4f17c14..4accf54 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,8 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'REROLL': + return 重绘; case 'INPAINT': return 局部重绘; case 'INPAINT_PRE': diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7221b9a..ee79368 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -95,6 +95,12 @@ const EditChannel = (props) => { case 26: localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; break; + case 2: + localModels = ['midjourney']; + break; + case 5: + localModels = ['midjourney']; + break; } setInputs((inputs) => ({...inputs, models: localModels})); } From 728dbed28d5738d1d3b4a4925f719cb05026af1f Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 16:29:27 +0800 Subject: [PATCH 03/16] =?UTF-8?q?feat:=20=E5=85=BC=E5=AE=B9=E5=8F=98?= =?UTF-8?q?=E7=84=A6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- constant/midjourney.go | 1 + relay/relay-mj.go | 5 ++++- web/src/components/MjLogsTable.js | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/constant/midjourney.go b/constant/midjourney.go index dbcc5c8..c184435 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -13,4 +13,5 @@ const ( MjActionVariation = "VARIATION" MjActionInPaint = "INPAINT" MjActionInPaintPre = "INPAINT_PRE" + MjActionZoom = "ZOOM" ) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 5fafc89..d582055 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -27,6 +27,7 @@ var DefaultModelPrice = map[string]float64{ "mj_reroll": 0.1, "mj_blend": 0.1, "mj_inpaint": 0.1, + "mj_zoom": 0.1, "mj_inpaint_pre": 0, "mj_describe": 0.05, "mj_upscale": 0.05, @@ -646,11 +647,13 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj midjRequest.Action = constant.MjActionUpscale } else if strings.Contains(action, "variation") { midjRequest.Action = constant.MjActionVariation + midjRequest.Index = 1 } else if strings.Contains(action, "pan") { midjRequest.Action = constant.MjActionVariation midjRequest.Index = 1 } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") { - midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Action = constant.MjActionZoom + midjRequest.Index = 1 } else if action == "Inpaint" { midjRequest.Action = constant.MjActionInPaintPre midjRequest.Index = 1 diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 4accf54..a1ffeb6 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -39,6 +39,8 @@ function renderType(type) { return 重绘; case 'INPAINT': return 局部重绘; + case 'ZOOM': + return 变焦; case 'INPAINT_PRE': return 局部重绘-预处理; default: From 2ad591411eff3f7f1ef91cc012e25ee915dea550 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 17:46:34 +0800 Subject: [PATCH 04/16] feat: support shorten --- Midjourney.md | 287 ++---------------------------- constant/midjourney.go | 1 + controller/midjourney.go | 4 + controller/relay.go | 3 + dto/midjourney.go | 40 +++-- model/midjourney.go | 1 + relay/constant/relay_mode.go | 1 + relay/relay-mj.go | 58 +++--- router/relay-router.go | 1 + web/src/components/MjLogsTable.js | 2 + 10 files changed, 74 insertions(+), 324 deletions(-) diff --git a/Midjourney.md b/Midjourney.md index fe4d433..becc9c9 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -7,285 +7,28 @@ ```json { "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05 + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05 } ``` ## 渠道设置 -### 对接 midjourney-proxy +### 对接 midjourney-proxy(plus) 1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) -2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney +2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face 3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080 4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填 ### 对接上游new api -1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney -2. 地址填写上游new api的地址,例如:http://localhost:8080 -3. 密钥填写上游new api的密钥 - -## 任务提交 - -### 绘图变化 - -**接口地址**:`/mj/submit/change` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "action" -: - "UPSCALE", - "index" -: - 1, - "notifyHook" -: - "", - "state" -: - "", - "taskId" -: - "1320098173412546" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------| -| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 | -|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | | -|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   state | 自定义参数 | | false | string | | -|   taskId | 任务ID | | true | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -### 提交Imagine任务 - -**接口地址**:`/mj/submit/imagine` - -**请求方式**:`POST` - -**请求数据类型**:`application/json` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求示例**: - -```javascript -{ - "base64" -: - "", - "notifyHook" -: - "", - "prompt" -: - "Cat", - "state" -: - "" -} -``` - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------------------------|-------------------------|------|-------|-------------|-------------| -| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 | -|   base64 | 垫图base64 | | false | string | | -|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | | -|   prompt | 提示词 | | true | string | | -|   state | 自定义参数 | | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 提交结果 | -| 201 | Created | | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|-------------------------------------------|----------------|----------------| -| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) | -| description | 描述 | string | | -| properties | 扩展字段 | object | | -| result | 任务ID | string | | - -**响应示例**: - -```javascript -{ - "code" -: - 1, - "description" -: - "提交成功", - "properties" -: - { - } -, - "result" -: - 1320098173412546 -} -``` - -## 任务查询 - -### 指定ID获取任务 - -**接口地址**:`/mj/task/{id}/fetch` - -**请求方式**:`GET` - -**请求数据类型**:`application/x-www-form-urlencoded` - -**响应数据类型**:`*/*` - -**接口描述**: - -**请求参数**: - -| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema | -|------|------|------|-------|--------|--------| -| id | 任务ID | path | false | string | | - -**响应状态**: - -| 状态码 | 说明 | schema | -|-----|--------------|--------| -| 200 | OK | 任务 | -| 401 | Unauthorized | | -| 403 | Forbidden | | -| 404 | Not Found | | - -**响应参数**: - -| 参数名称 | 参数说明 | 类型 | schema | -|-------------|----------------------------------------------------------|----------------|----------------| -| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | | -| description | 任务描述 | string | | -| failReason | 失败原因 | string | | -| finishTime | 结束时间 | integer(int64) | integer(int64) | -| id | 任务ID | string | | -| imageUrl | 图片url | string | | -| progress | 任务进度 | string | | -| prompt | 提示词 | string | | -| promptEn | 提示词-英文 | string | | -| startTime | 开始执行时间 | integer(int64) | integer(int64) | -| state | 自定义参数 | string | | -| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | | -| submitTime | 提交时间 | integer(int64) | integer(int64) | - -**响应示例**: - -```javascript -{ - "action" -: - "", - "description" -: - "", - "failReason" -: - "", - "finishTime" -: - 0, - "id" -: - "", - "imageUrl" -: - "", - "progress" -: - "", - "prompt" -: - "", - "promptEn" -: - "", - "startTime" -: - 0, - "state" -: - "", - "status" -: - "", - "submitTime" -: - 0 -} -``` \ No newline at end of file +1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face +2. 地址填写上游new api的地址,例如:http://localhost:3000 +3. 密钥填写上游new api的密钥 \ No newline at end of file diff --git a/constant/midjourney.go b/constant/midjourney.go index c184435..a5bccb7 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -14,4 +14,5 @@ const ( MjActionInPaint = "INPAINT" MjActionInPaintPre = "INPAINT_PRE" MjActionZoom = "ZOOM" + MjActionShorten = "SHORTEN" ) diff --git a/controller/midjourney.go b/controller/midjourney.go index cac253c..b666e91 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -263,6 +263,10 @@ func UpdateMidjourneyTaskBulk() { task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason + if responseItem.Properties != nil { + propertiesStr, _ := json.Marshal(responseItem.Properties) + task.Properties = string(propertiesStr) + } if responseItem.Buttons != nil { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) diff --git a/controller/relay.go b/controller/relay.go index a42db2e..7652840 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -68,6 +68,9 @@ func RelayMidjourney(c *gin.Context) { } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") { // midjourney plus relayMode = relayconstant.RelayModeMidjourneyModal + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyShorten } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { relayMode = relayconstant.RelayModeMidjourneyImagine } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { diff --git a/dto/midjourney.go b/dto/midjourney.go index 4fef4e1..d3b19d5 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -22,23 +22,24 @@ type MidjourneyResponse struct { } type MidjourneyDto struct { - MjId string `json:"id"` - Action string `json:"action"` - CustomId string `json:"customId"` - BotType string `json:"botType"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - ImageUrl string `json:"imageUrl"` - Status string `json:"status"` - Progress string `json:"progress"` - FailReason string `json:"failReason"` - Buttons any `json:"buttons"` - MaskBase64 string `json:"maskBase64"` + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + BotType string `json:"botType"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` + Properties *Properties `json:"properties"` } type MidjourneyStatus struct { @@ -70,3 +71,8 @@ type ActionButton struct { Type any `json:"type"` Style any `json:"style"` } + +type Properties struct { + FinalPrompt string `json:"finalPrompt"` + FinalZhPrompt string `json:"finalZhPrompt"` +} diff --git a/model/midjourney.go b/model/midjourney.go index f20ab32..dd065a3 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -20,6 +20,7 @@ type Midjourney struct { ChannelId int `json:"channel_id"` Quota int `json:"quota"` Buttons string `json:"buttons"` + Properties string `json:"properties"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index c49caae..d8dc7ee 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -23,6 +23,7 @@ const ( RelayModeAudioTranslation RelayModeMidjourneyAction RelayModeMidjourneyModal + RelayModeMidjourneyShorten ) func Path2RelayMode(path string) int { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index d582055..a1f6ed4 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -31,6 +31,7 @@ var DefaultModelPrice = map[string]float64{ "mj_inpaint_pre": 0, "mj_describe": 0.05, "mj_upscale": 0.05, + "swap_face": 0.05, } func RelayMidjourneyImage(c *gin.Context) { @@ -140,6 +141,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Buttons = buttons } } + if originTask.Properties != "" { + var properties dto.Properties + err := json.Unmarshal([]byte(originTask.Properties), &properties) + if err == nil { + midjourneyTask.Properties = &properties + } + } return } @@ -260,9 +268,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } - midjRequest.Action = "IMAGINE" + midjRequest.Action = constant.MjActionImagine } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 - midjRequest.Action = "DESCRIBE" + midjRequest.Action = constant.MjActionDescribe + } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = "BLEND" } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 @@ -292,7 +302,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") } mjId = midjRequest.TaskId - midjRequest.Action = "INPAINT" + midjRequest.Action = constant.MjActionInPaint } originTask := model.GetByMJId(userId, mjId) @@ -418,25 +428,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer cancel() resp, err := service.GetHttpClient().Do(req) if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "do_request_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed") } err = req.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "close_request_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") } err = c.Request.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 5, - Description: "close_request_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") } var midjResponse dto.MidjourneyResponse @@ -464,33 +465,20 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody, err := io.ReadAll(resp.Body) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "read_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed") } err = resp.Body.Close() if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "close_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed") + } + if resp.StatusCode != 200 { + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status") } - err = json.Unmarshal(responseBody, &midjResponse) log.Printf("responseBody: %s", string(responseBody)) log.Printf("midjResponse: %v", midjResponse) - if resp.StatusCode != 200 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode), - } - } if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_response_body_failed", - } + return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed") } // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md @@ -651,7 +639,7 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj } else if strings.Contains(action, "pan") { midjRequest.Action = constant.MjActionVariation midjRequest.Index = 1 - } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") { + } else if action == "Outpaint" || action == "CustomZoom" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 } else if action == "Inpaint" { diff --git a/router/relay-router.go b/router/relay-router.go index 68b762b..f572d8f 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -48,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relayMjRouter.POST("/submit/action", controller.RelayMidjourney) + relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney) relayMjRouter.POST("/submit/modal", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index a1ffeb6..fe6554e 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,8 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'SHORTEN': + return 缩词; case 'REROLL': return 重绘; case 'INPAINT': From d5ffaf25027feb11e0564febb4fb31b6d3abdb56 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 18:26:16 +0800 Subject: [PATCH 05/16] =?UTF-8?q?feat:=20=E6=93=8D=E4=BD=9C=E7=BB=86?= =?UTF-8?q?=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Midjourney.md | 55 ++++++++++++++++++++------ constant/midjourney.go | 22 ++++++----- controller/relay.go | 32 +-------------- relay/constant/relay_mode.go | 31 +++++++++++++++ relay/relay-mj.go | 66 +++++++++++++------------------ web/src/components/MjLogsTable.js | 6 +++ 6 files changed, 122 insertions(+), 90 deletions(-) diff --git a/Midjourney.md b/Midjourney.md index becc9c9..d495e84 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -4,31 +4,62 @@ ## 模型价格设置(在设置-运营设置-模型固定价格设置中设置) +### 模型列表 + +### midjourney-proxy支持 + +- mj_imagine (绘图) +- mj_variation (变换) +- mj_reroll (重绘) +- mj_blend (混合) +- mj_upscale (放大) +- mj_describe (图生文) + +### 仅midjourney-proxy-plus支持 + +- mj_zoom (比例变焦) +- mj_shorten (提示词缩短) +- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加) +- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加) +- mj_high_variation (强变换) +- mj_low_variation (弱变换) +- mj_pan (平移) +- swap_face (换脸) + ```json { - "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_inpaint": 0.1, - "mj_zoom": 0.1, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, "mj_inpaint_pre": 0, - "mj_describe": 0.05, - "mj_upscale": 0.05, - "swap_face": 0.05 + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05 } ``` ## 渠道设置 ### 对接 midjourney-proxy(plus) -1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) -2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face + +1. + +部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy) + +2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus** + ,模型选择midjourney,如果有换脸模型,可以选择swap_face 3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080 4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填 ### 对接上游new api + 1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face 2. 地址填写上游new api的地址,例如:http://localhost:3000 3. 密钥填写上游new api的密钥 \ No newline at end of file diff --git a/constant/midjourney.go b/constant/midjourney.go index a5bccb7..5435a43 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -6,13 +6,17 @@ const ( ) const ( - MjActionImagine = "IMAGINE" - MjActionDescribe = "DESCRIBE" - MjActionBlend = "BLEND" - MjActionUpscale = "UPSCALE" - MjActionVariation = "VARIATION" - MjActionInPaint = "INPAINT" - MjActionInPaintPre = "INPAINT_PRE" - MjActionZoom = "ZOOM" - MjActionShorten = "SHORTEN" + MjActionImagine = "IMAGINE" + MjActionDescribe = "DESCRIBE" + MjActionBlend = "BLEND" + MjActionUpscale = "UPSCALE" + MjActionVariation = "VARIATION" + MjActionInPaint = "INPAINT" + MjActionInPaintPre = "INPAINT_PRE" + MjActionZoom = "ZOOM" + MjActionShorten = "SHORTEN" + MjActionHighVariation = "HIGH_VARIATION" + MjActionLowVariation = "LOW_VARIATION" + MjActionPan = "PAN" + SwapFace = "SWAP_FACE" ) diff --git a/controller/relay.go b/controller/relay.go index 7652840..d35c6a2 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -12,7 +12,6 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strconv" - "strings" ) func Relay(c *gin.Context) { @@ -61,42 +60,13 @@ func Relay(c *gin.Context) { } func RelayMidjourney(c *gin.Context) { - relayMode := relayconstant.RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") { - // midjourney plus - relayMode = relayconstant.RelayModeMidjourneyAction - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") { - // midjourney plus - relayMode = relayconstant.RelayModeMidjourneyModal - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") { - // midjourney plus - relayMode = relayconstant.RelayModeMidjourneyShorten - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { - relayMode = relayconstant.RelayModeMidjourneyImagine - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { - relayMode = relayconstant.RelayModeMidjourneyBlend - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") { - relayMode = relayconstant.RelayModeMidjourneyDescribe - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") { - relayMode = relayconstant.RelayModeMidjourneyNotify - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") { - relayMode = relayconstant.RelayModeMidjourneyChange - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") { - relayMode = relayconstant.RelayModeMidjourneyChange - } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") { - relayMode = relayconstant.RelayModeMidjourneyTaskFetch - } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") { - relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition - } - + relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path) var err *dto.MidjourneyResponse switch relayMode { case relayconstant.RelayModeMidjourneyNotify: err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) - //case relayconstant.RelayModeMidjourneyModal: - // err = relay.RelayMidjournneyModal(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index d8dc7ee..9f13726 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -51,3 +51,34 @@ func Path2RelayMode(path string) int { } return relayMode } + +func Path2RelayModeMidjourney(path string) int { + relayMode := RelayModeUnknown + if strings.HasPrefix(path, "/mj/submit/action") { + // midjourney plus + relayMode = RelayModeMidjourneyAction + } else if strings.HasPrefix(path, "/mj/submit/modal") { + // midjourney plus + relayMode = RelayModeMidjourneyModal + } else if strings.HasPrefix(path, "/mj/submit/shorten") { + // midjourney plus + relayMode = RelayModeMidjourneyShorten + } else if strings.HasPrefix(path, "/mj/submit/imagine") { + relayMode = RelayModeMidjourneyImagine + } else if strings.HasPrefix(path, "/mj/submit/blend") { + relayMode = RelayModeMidjourneyBlend + } else if strings.HasPrefix(path, "/mj/submit/describe") { + relayMode = RelayModeMidjourneyDescribe + } else if strings.HasPrefix(path, "/mj/notify") { + relayMode = RelayModeMidjourneyNotify + } else if strings.HasPrefix(path, "/mj/submit/change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasPrefix(path, "/mj/submit/simple-change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasSuffix(path, "/fetch") { + relayMode = RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(path, "/list-by-condition") { + relayMode = RelayModeMidjourneyTaskFetchByCondition + } + return relayMode +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index a1f6ed4..f391f14 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -22,16 +22,20 @@ import ( ) var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_inpaint": 0.1, - "mj_zoom": 0.1, - "mj_inpaint_pre": 0, - "mj_describe": 0.05, - "mj_upscale": 0.05, - "swap_face": 0.05, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05, } func RelayMidjourneyImage(c *gin.Context) { @@ -151,31 +155,6 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse { - userId := c.GetInt("id") - var midjRequest dto.MidjourneyRequest - err := common.UnmarshalBodyReusable(c, &midjRequest) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") - } - originTask := model.GetByMJId(userId, midjRequest.TaskId) - if originTask == nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") - } - - respBody, err := json.Marshal(midjRequest) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") - } - c.Writer.Header().Set("Content-Type", "application/json") - _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") - } - return nil - -} - func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -274,7 +253,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 - midjRequest.Action = "BLEND" + midjRequest.Action = constant.MjActionBlend } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { @@ -634,10 +613,21 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj midjRequest.Index = index midjRequest.Action = constant.MjActionUpscale } else if strings.Contains(action, "variation") { - midjRequest.Action = constant.MjActionVariation midjRequest.Index = 1 + if action == "variation" { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionVariation + } else if action == "low_variation" { + midjRequest.Action = constant.MjActionLowVariation + } else if action == "high_variation" { + midjRequest.Action = constant.MjActionHighVariation + } } else if strings.Contains(action, "pan") { - midjRequest.Action = constant.MjActionVariation + midjRequest.Action = constant.MjActionPan midjRequest.Index = 1 } else if action == "Outpaint" || action == "CustomZoom" { midjRequest.Action = constant.MjActionZoom diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index fe6554e..603d345 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -31,6 +31,12 @@ function renderType(type) { return 放大; case 'VARIATION': return 变换; + case 'HIGH_VARIATION': + return 强变换; + case 'LOW_VARIATION': + return 弱变换; + case 'PAN': + return 平移; case 'DESCRIBE': return 图生文; case 'BLEAND': From 3d10c9f090300c5653823556ca224a5d86cb86e7 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 21:19:48 +0800 Subject: [PATCH 06/16] =?UTF-8?q?feat:=20=E5=B0=86=E6=93=8D=E4=BD=9C?= =?UTF-8?q?=E6=8B=86=E5=88=86=E6=88=90=E5=8D=95=E7=8B=AC=E7=9A=84=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 32 +++-- constant/midjourney.go | 18 +++ controller/midjourney.go | 131 ------------------ controller/model.go | 18 ++- controller/relay.go | 12 +- dto/midjourney.go | 7 + middleware/auth.go | 8 +- middleware/distributor.go | 129 +++++++++++------- middleware/utils.go | 12 +- relay/relay-mj.go | 195 +++++++-------------------- service/midjourney.go | 135 +++++++++++++++++++ web/src/pages/Channel/EditChannel.js | 19 ++- 12 files changed, 366 insertions(+), 350 deletions(-) create mode 100644 service/midjourney.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 791f733..153b748 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -94,17 +94,30 @@ var ModelRatio = map[string]float64{ "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } -var ModelPrice = map[string]float64{ - "gpt-4-gizmo-*": 0.1, - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, +var DefaultModelPrice = map[string]float64{ + "gpt-4-gizmo-*": 0.1, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_zoom": 0.1, + "mj_shorten": 0.1, + "mj_high_variation": 0.1, + "mj_low_variation": 0.1, + "mj_pan": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, + "swap_face": 0.05, } +var ModelPrice = map[string]float64{} + func ModelPrice2JSONString() string { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } jsonBytes, err := json.Marshal(ModelPrice) if err != nil { SysError("error marshalling model price: " + err.Error()) @@ -118,6 +131,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error { } func GetModelPrice(name string, printErr bool) float64 { + if len(ModelPrice) == 0 { + ModelPrice = DefaultModelPrice + } if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } diff --git a/constant/midjourney.go b/constant/midjourney.go index 5435a43..92e2f23 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -11,6 +11,7 @@ const ( MjActionBlend = "BLEND" MjActionUpscale = "UPSCALE" MjActionVariation = "VARIATION" + MjActionReRoll = "REROLL" MjActionInPaint = "INPAINT" MjActionInPaintPre = "INPAINT_PRE" MjActionZoom = "ZOOM" @@ -20,3 +21,20 @@ const ( MjActionPan = "PAN" SwapFace = "SWAP_FACE" ) + +var MidjourneyModel2Action = map[string]string{ + "mj_imagine": MjActionImagine, + "mj_describe": MjActionDescribe, + "mj_blend": MjActionBlend, + "mj_upscale": MjActionUpscale, + "mj_variation": MjActionVariation, + "mj_reroll": MjActionReRoll, + "mj_inpaint": MjActionInPaint, + "mj_inpaint_pre": MjActionInPaintPre, + "mj_zoom": MjActionZoom, + "mj_shorten": MjActionShorten, + "mj_high_variation": MjActionHighVariation, + "mj_low_variation": MjActionLowVariation, + "mj_pan": MjActionPan, + "swap_face": SwapFace, +} diff --git a/controller/midjourney.go b/controller/midjourney.go index b666e91..6256471 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -18,137 +18,6 @@ import ( "time" ) -/*func UpdateMidjourneyTask() { - //revocer - //imageModel := "midjourney" - ctx := context.TODO() - imageModel := "midjourney" - defer func() { - if err := recover(); err != nil { - log.Printf("UpdateMidjourneyTask panic: %v", err) - } - }() - for { - time.Sleep(time.Duration(15) * time.Second) - tasks := model.GetAllUnFinishTasks() - if len(tasks) != 0 { - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) - for _, task := range tasks { - common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task)) - midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) - if err != nil { - common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err)) - task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId) - task.Status = "FAILURE" - task.Progress = "100%" - err := task.Update() - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) - continue - } - continue - } - requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) - common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl)) - - req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) - if err != nil { - common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err)) - continue - } - - // 设置超时时间 - timeout := time.Second * 5 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) - - req.Header.Set("Content-Type", "application/json") - //req.Header.Set("ApiKey", "Bearer midjourney-proxy") - req.Header.Set("mj-api-secret", midjourneyChannel.Key) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) - continue - } - responseBody, err := io.ReadAll(resp.Body) - resp.Body.Close() - log.Printf("responseBody: %s", string(responseBody)) - var responseItem MidjourneyDto - // err = json.NewDecoder(resp.Body).Decode(&responseItem) - err = json.Unmarshal(responseBody, &responseItem) - if err != nil { - if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") { - var responseWithoutStatus MidjourneyWithoutStatus - var responseStatus MidjourneyStatus - err1 := json.Unmarshal(responseBody, &responseWithoutStatus) - err2 := json.Unmarshal(responseBody, &responseStatus) - if err1 == nil && err2 == nil { - jsonData, err3 := json.Marshal(responseWithoutStatus) - if err3 != nil { - log.Printf("UpdateMidjourneyTask error1: %v", err3) - continue - } - err4 := json.Unmarshal(jsonData, &responseStatus) - if err4 != nil { - log.Printf("UpdateMidjourneyTask error2: %v", err4) - continue - } - responseItem.Status = strconv.Itoa(responseStatus.Status) - } else { - log.Printf("UpdateMidjourneyTask error3: %v", err) - continue - } - } else { - log.Printf("UpdateMidjourneyTask error4: %v", err) - continue - } - } - task.Code = 1 - task.Progress = responseItem.Progress - task.PromptEn = responseItem.PromptEn - task.State = responseItem.State - task.SubmitTime = responseItem.SubmitTime - task.StartTime = responseItem.StartTime - task.FinishTime = responseItem.FinishTime - task.ImageUrl = responseItem.ImageUrl - task.Status = responseItem.Status - task.FailReason = responseItem.FailReason - if task.Progress != "100%" && responseItem.FailReason != "" { - common.LogWarn(task.MjId + " 构建失败," + task.FailReason) - task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) - if err != nil { - log.Println("error update user quota cache: " + err.Error()) - } else { - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio("default") - ratio := modelRatio * groupRatio - quota := int(ratio * 1 * 1000) - if quota != 0 { - err := model.IncreaseUserQuota(task.UserId, quota) - if err != nil { - log.Println("fail to increase user quota") - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - - err = task.Update() - if err != nil { - log.Printf("UpdateMidjourneyTask error5: %v", err) - } - log.Printf("UpdateMidjourneyTask success: %v", task) - cancel() - } - } - } -} -*/ - func UpdateMidjourneyTaskBulk() { //imageModel := "midjourney" ctx := context.TODO() diff --git a/controller/model.go b/controller/model.go index 38c6c46..9a106aa 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" "one-api/relay/channel/moonshot" - "one-api/relay/constant" + relayconstant "one-api/relay/constant" ) // https://platform.openai.com/docs/api-reference/models/list @@ -59,8 +60,8 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < relayconstant.APITypeDummy; i++ { + if i == relayconstant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) @@ -100,6 +101,17 @@ func init() { Parent: nil, }) } + for modelName, _ := range constant.MidjourneyModel2Action { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "midjourney", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/controller/relay.go b/controller/relay.go index d35c6a2..fa5493a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { } func RelayMidjourney(c *gin.Context) { - relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path) + relayMode := c.GetInt("relay_mode") var err *dto.MidjourneyResponse switch relayMode { case relayconstant.RelayModeMidjourneyNotify: @@ -73,13 +73,15 @@ func RelayMidjourney(c *gin.Context) { //err = relayMidjourneySubmit(c, relayMode) log.Println(err) if err != nil { + statusCode := http.StatusBadRequest if err.Code == 30 { err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + statusCode = http.StatusTooManyRequests } - c.JSON(429, gin.H{ - "error": fmt.Sprintf("%s %s", err.Description, err.Result), - "type": "upstream_error", - "code": err.Code, + c.JSON(statusCode, gin.H{ + "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": "upstream_error", + "code": err.Code, }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) diff --git a/dto/midjourney.go b/dto/midjourney.go index d3b19d5..f81756c 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -1,5 +1,12 @@ package dto +//type SimpleMjRequest struct { +// Prompt string `json:"prompt"` +// CustomId string `json:"customId"` +// Action string `json:"action"` +// Content string `json:"content"` +//} + type MidjourneyRequest struct { Prompt string `json:"prompt"` CustomId string `json:"customId"` diff --git a/middleware/auth.go b/middleware/auth.go index a8dac30..4b865c2 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) { } token, err := model.ValidateUserToken(key) if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) + abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -129,7 +129,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 1ca43dd..c1b3ccb 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,7 +4,11 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + relayconstant "one-api/relay/constant" + "one-api/service" "strconv" "strings" @@ -23,32 +27,58 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } if channel.Status != common.ChannelStatusEnabled { - abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") + abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { + shouldSelectChannel := true // Select a channel for the user var modelRequest ModelRequest var err error if strings.HasPrefix(c.Request.URL.Path, "/mj") { - // Midjourney - if modelRequest.Model == "" { - modelRequest.Model = "midjourney" + relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) + if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || + relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || + relayMode == relayconstant.RelayModeMidjourneyNotify { + shouldSelectChannel = false + } else { + midjourneyRequest := dto.MidjourneyRequest{} + err = common.UnmarshalBodyReusable(c, &midjourneyRequest) + if err != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) + return + } + midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) + if mjErr != nil { + abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) + return + } + if midjourneyModel == "" { + if !success { + abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") + return + } else { + // task fetch, task fetch by condition, notify + shouldSelectChannel = false + } + } + modelRequest.Model = midjourneyModel } + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -87,60 +117,61 @@ func Distribute() func(c *gin.Context) { } if tokenModelLimit != nil { if _, ok := tokenModelLimit[modelRequest.Model]; !ok { - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) return } } else { // token model limit is empty, all models are not allowed - abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") return } } userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) - - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) - // 如果错误,但是渠道不为空,说明是数据库一致性问题 - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" + if shouldSelectChannel { + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + // 如果错误,但是渠道不为空,说明是数据库一致性问题 + if channel != nil { + common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + message = "数据库一致性已被破坏,请联系管理员" + } + // 如果错误,而且渠道为空,说明是没有可用渠道 + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + return + } + if channel == nil { + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) + return + } + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + c.Set("auto_ban", ban) + c.Set("model_mapping", channel.GetModelMapping()) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 + switch channel.Type { + case common.ChannelTypeAzure: + c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } - // 如果错误,而且渠道为空,说明是没有可用渠道 - abortWithMessage(c, http.StatusServiceUnavailable, message) - return } - if channel == nil { - abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) - return - } - } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - c.Set("auto_ban", ban) - c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - // TODO: api_version统一 - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) } c.Next() } diff --git a/middleware/utils.go b/middleware/utils.go index 021002d..43801c1 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -5,7 +5,7 @@ import ( "one-api/common" ) -func abortWithMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), @@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() common.LogError(c.Request.Context(), message) } + +func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { + c.JSON(statusCode, gin.H{ + "description": description, + "type": "new_api_error", + "code": code, + }) + c.Abort() + common.LogError(c.Request.Context(), description) +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index f391f14..6cdd9e0 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -21,23 +21,6 @@ import ( "github.com/gin-gonic/gin" ) -var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_inpaint": 0.1, - "mj_zoom": 0.1, - "mj_shorten": 0.1, - "mj_high_variation": 0.1, - "mj_low_variation": 0.1, - "mj_pan": 0.1, - "mj_inpaint_pre": 0, - "mj_describe": 0.05, - "mj_upscale": 0.05, - "swap_face": 0.05, -} - func RelayMidjourneyImage(c *gin.Context) { taskId := c.Param("id") midjourneyTask := model.GetByOnlyMJId(taskId) @@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse } func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - imageModel := "midjourney" tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") + //channelType := c.GetInt("channel") userId := c.GetInt("id") group := c.GetString("group") channelId := c.GetInt("channel_id") @@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 - mjErr := coverPlusActionToNormalAction(&midjRequest) + mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } @@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } - params := convertSimpleChangeParams(midjRequest.Content) + params := service.ConvertSimpleChangeParams(midjRequest.Content) if params == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") } - mjId = params.ID + mjId = params.TaskId midjRequest.Action = params.Action } else if relayMode == relayconstant.RelayModeMidjourneyModal { if midjRequest.MaskBase64 == "" { @@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt - if channelType == common.ChannelTypeMidjourneyPlus { - // plus - } else { - // 普通版渠道 - - } + //if channelType == common.ChannelTypeMidjourneyPlus { + // // plus + //} else { + // // 普通版渠道 + // + //} } if midjRequest.Action == constant.MjActionInPaintPre { @@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - return &dto.MidjourneyResponse{ - Code: 4, - Description: "unmarshal_model_mapping_failed", - } - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } - } + //modelMapping := c.GetString("model_mapping") + //isModelMapped := false + //if modelMapping != "" { + // modelMap := make(map[string]string) + // err := json.Unmarshal([]byte(modelMapping), &modelMap) + // if err != nil { + // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + // return &dto.MidjourneyResponse{ + // Code: 4, + // Description: "unmarshal_model_mapping_failed", + // } + // } + // if modelMap[imageModel] != "" { + // imageModel = modelMap[imageModel] + // isModelMapped = true + // } + //} - baseURL := common.ChannelBaseURLs[channelType] + //baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } + baseURL := c.GetString("base_url") //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "marshal_text_request_failed", - } - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } + //if isModelMapped { + // jsonStr, err := json.Marshal(midjRequest) + // if err != nil { + // return &dto.MidjourneyResponse{ + // Code: 4, + // Description: "marshal_text_request_failed", + // } + // } + // requestBody = bytes.NewBuffer(jsonStr) + //} else { + //} + requestBody = c.Request.Body - mjAction := "mj_" + strings.ToLower(midjRequest.Action) - modelPrice := common.GetModelPrice(mjAction, true) + modelName := service.CoverActionToModelName(midjRequest.Action) + modelPrice := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if modelPrice == -1 { - defaultPrice, ok := DefaultModelPrice[mjAction] + defaultPrice, ok := common.DefaultModelPrice[modelName] if !ok { modelPrice = 0.1 } else { @@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) @@ -558,85 +541,3 @@ type taskChangeParams struct { Action string Index int } - -func convertSimpleChangeParams(content string) *taskChangeParams { - split := strings.Split(content, " ") - if len(split) != 2 { - return nil - } - - action := strings.ToLower(split[1]) - changeParams := &taskChangeParams{} - changeParams.ID = split[0] - - if action[0] == 'u' { - changeParams.Action = "UPSCALE" - } else if action[0] == 'v' { - changeParams.Action = "VARIATION" - } else if action == "r" { - changeParams.Action = "REROLL" - return changeParams - } else { - return nil - } - - index, err := strconv.Atoi(action[1:2]) - if err != nil || index < 1 || index > 4 { - return nil - } - changeParams.Index = index - return changeParams -} - -func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { - // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" - customId := midjRequest.CustomId - if customId == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") - } - splits := strings.Split(customId, "::") - var action string - if splits[1] == "JOB" { - action = splits[2] - } else { - action = splits[1] - } - - if action == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") - } - if strings.Contains(action, "upsample") { - index, err := strconv.Atoi(splits[3]) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") - } - midjRequest.Index = index - midjRequest.Action = constant.MjActionUpscale - } else if strings.Contains(action, "variation") { - midjRequest.Index = 1 - if action == "variation" { - index, err := strconv.Atoi(splits[3]) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") - } - midjRequest.Index = index - midjRequest.Action = constant.MjActionVariation - } else if action == "low_variation" { - midjRequest.Action = constant.MjActionLowVariation - } else if action == "high_variation" { - midjRequest.Action = constant.MjActionHighVariation - } - } else if strings.Contains(action, "pan") { - midjRequest.Action = constant.MjActionPan - midjRequest.Index = 1 - } else if action == "Outpaint" || action == "CustomZoom" { - midjRequest.Action = constant.MjActionZoom - midjRequest.Index = 1 - } else if action == "Inpaint" { - midjRequest.Action = constant.MjActionInPaintPre - midjRequest.Index = 1 - } else { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") - } - return nil -} diff --git a/service/midjourney.go b/service/midjourney.go new file mode 100644 index 0000000..06730c8 --- /dev/null +++ b/service/midjourney.go @@ -0,0 +1,135 @@ +package service + +import ( + "one-api/constant" + "one-api/dto" + relayconstant "one-api/relay/constant" + "strconv" + "strings" +) + +func CoverActionToModelName(mjAction string) string { + modelName := "mj_" + strings.ToLower(mjAction) + return modelName +} + +func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) { + action := "" + if relayMode == relayconstant.RelayModeMidjourneyAction { + // plus request + err := CoverPlusActionToNormalAction(midjRequest) + if err != nil { + return "", err, false + } + action = midjRequest.Action + } else { + switch relayMode { + case relayconstant.RelayModeMidjourneyImagine: + action = constant.MjActionImagine + case relayconstant.RelayModeMidjourneyDescribe: + action = constant.MjActionDescribe + case relayconstant.RelayModeMidjourneyBlend: + action = constant.MjActionBlend + case relayconstant.RelayModeMidjourneyShorten: + action = constant.MjActionShorten + case relayconstant.RelayModeMidjourneyChange: + action = midjRequest.Action + case relayconstant.RelayModeMidjourneyModal: + action = constant.MjActionInPaint + case relayconstant.RelayModeMidjourneySimpleChange: + params := ConvertSimpleChangeParams(midjRequest.Content) + if params == nil { + return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false + } + action = params.Action + case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: + return "", nil, true + default: + return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false + } + } + modelName := CoverActionToModelName(action) + return modelName, nil, true +} + +func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Index = 1 + if action == "variation" { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionVariation + } else if action == "low_variation" { + midjRequest.Action = constant.MjActionLowVariation + } else if action == "high_variation" { + midjRequest.Action = constant.MjActionHighVariation + } + } else if strings.Contains(action, "pan") { + midjRequest.Action = constant.MjActionPan + midjRequest.Index = 1 + } else if action == "Outpaint" || action == "CustomZoom" { + midjRequest.Action = constant.MjActionZoom + midjRequest.Index = 1 + } else if action == "Inpaint" { + midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Index = 1 + } else { + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + return nil +} + +func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { + split := strings.Split(content, " ") + if len(split) != 2 { + return nil + } + + action := strings.ToLower(split[1]) + changeParams := &dto.MidjourneyRequest{} + changeParams.TaskId = split[0] + + if action[0] == 'u' { + changeParams.Action = "UPSCALE" + } else if action[0] == 'v' { + changeParams.Action = "VARIATION" + } else if action == "r" { + changeParams.Action = "REROLL" + return changeParams + } else { + return nil + } + + index, err := strconv.Atoi(action[1:2]) + if err != nil || index < 1 || index > 4 { + return nil + } + changeParams.Index = index + return changeParams +} diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index ee79368..225ce3f 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -96,10 +96,25 @@ const EditChannel = (props) => { localModels = ['glm-4', 'glm-4v', 'glm-3-turbo']; break; case 2: - localModels = ['midjourney']; + localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe']; break; case 5: - localModels = ['midjourney']; + localModels = [ + 'swap_face', + 'mj_imagine', + 'mj_variation', + 'mj_reroll', + 'mj_blend', + 'mj_upscale', + 'mj_describe', + 'mj_zoom', + 'mj_shorten', + 'mj_inpaint_pre', + 'mj_inpaint_pre', + 'mj_high_variation', + 'mj_low_variation', + 'mj_pan', + ]; break; } setInputs((inputs) => ({...inputs, models: localModels})); From d3399d68f6aae513709bc15646469c1ae652c93c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 22:24:02 +0800 Subject: [PATCH 07/16] fix: fix typo --- web/src/components/MjLogsTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 603d345..88da1cd 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -39,7 +39,7 @@ function renderType(type) { return 平移; case 'DESCRIBE': return 图生文; - case 'BLEAND': + case 'BLEND': return 图混合; case 'SHORTEN': return 缩词; From 9b2e5c2978721ba6273a5c7b3b33ba9a2a1f3503 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 22:30:10 +0800 Subject: [PATCH 08/16] refactor: remove consumeQuota --- relay/relay-image.go | 77 ++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/relay/relay-image.go b/relay/relay-image.go index 3065496..aabe4ba 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") startTime := time.Now() var imageRequest dto.ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if imageRequest.Model == "" { @@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC var textResponse dto.ImageResponse defer func(ctx context.Context) { useTimeSeconds := time.Now().Unix() - startTime.Unix() - if consumeQuota { - if resp.StatusCode != http.StatusOK { - return - } - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + if resp.StatusCode != http.StatusOK { + return + } + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) From 44361d75e88d066a8554dfeff384606d19df2d29 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 13 Mar 2024 23:17:12 +0800 Subject: [PATCH 09/16] fix: "Inpaint" code error --- relay/relay-mj.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 6cdd9e0..7342717 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -494,8 +494,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //修改返回值 - newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) - responseBody = []byte(newBody) + if midjRequest.Action != constant.MjActionInPaintPre { + newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) + responseBody = []byte(newBody) + } } err = midjourneyTask.Insert() From a77fbc0fa214184391224c12340da09d63acce7f Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 00:43:32 +0800 Subject: [PATCH 10/16] fix: reroll action error --- model/ability.go | 7 ++++++- service/midjourney.go | 7 +++++-- web/src/pages/Channel/EditChannel.js | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/model/ability.go b/model/ability.go index 7a81cc2..b79978d 100644 --- a/model/ability.go +++ b/model/ability.go @@ -147,7 +147,12 @@ func FixAbility() (int, error) { return 0, err } var channels []Channel - err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error + + if len(abilityChannelIds) == 0 { + err = DB.Find(&channels).Error + } else { + err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error + } if err != nil { return 0, err } diff --git a/service/midjourney.go b/service/midjourney.go index 06730c8..c04c4d3 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -45,7 +45,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: return "", nil, true default: - return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false + return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false } } modelName := CoverActionToModelName(action) @@ -93,6 +93,9 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj } else if strings.Contains(action, "pan") { midjRequest.Action = constant.MjActionPan midjRequest.Index = 1 + } else if strings.Contains(action, "reroll") { + midjRequest.Action = constant.MjActionReRoll + midjRequest.Index = 1 } else if action == "Outpaint" || action == "CustomZoom" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 @@ -100,7 +103,7 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj midjRequest.Action = constant.MjActionInPaintPre midjRequest.Index = 1 } else { - return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) } return nil } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 225ce3f..757b56c 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -110,7 +110,7 @@ const EditChannel = (props) => { 'mj_zoom', 'mj_shorten', 'mj_inpaint_pre', - 'mj_inpaint_pre', + 'mj_inpaint', 'mj_high_variation', 'mj_low_variation', 'mj_pan', From 614220a0fb191c51e1f1bda5e9497536fe952e23 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 15:16:36 +0800 Subject: [PATCH 11/16] =?UTF-8?q?feat:=20=E8=B6=85=E8=BF=87=E4=B8=80?= =?UTF-8?q?=E5=B0=8F=E6=97=B6=E7=9A=84=E4=BB=BB=E5=8A=A1=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/midjourney.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index 6256471..41db4bf 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -118,10 +118,16 @@ func UpdateMidjourneyTaskBulk() { for _, responseItem := range responseItems { task := taskM[responseItem.MjId] + + useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime + // 如果时间超过一小时,且进度不是100%,则认为任务失败 + if useTime > 3600000 && task.Progress != "100%" { + responseItem.FailReason = "上游任务超时(超过1小时)" + responseItem.Status = "FAILURE" + } if !checkMjTaskNeedUpdate(task, responseItem) { continue } - task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -140,6 +146,7 @@ func UpdateMidjourneyTaskBulk() { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } + if task.Progress != "100%" && responseItem.FailReason != "" { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" From d704902b70d34dafcadc1f54cca75bb2c986f7ef Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 16:42:37 +0800 Subject: [PATCH 12/16] =?UTF-8?q?feat:=20=E5=85=BC=E5=AE=B9=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=8F=98=E7=84=A6=EF=BC=8C=E5=AE=8C=E5=96=84?= =?UTF-8?q?modal=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 5 +- constant/midjourney.go | 6 +- controller/relay.go | 2 + dto/midjourney.go | 5 + middleware/distributor.go | 3 +- relay/constant/relay_mode.go | 3 + relay/relay-mj.go | 139 ++++++++------------------- router/relay-router.go | 1 + service/error.go | 7 ++ service/midjourney.go | 73 +++++++++++++- web/src/components/MjLogsTable.js | 10 +- web/src/pages/Channel/EditChannel.js | 3 +- 12 files changed, 147 insertions(+), 110 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 153b748..3231d95 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -100,13 +100,14 @@ var DefaultModelPrice = map[string]float64{ "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, - "mj_inpaint": 0.1, + "mj_modal": 0.1, "mj_zoom": 0.1, "mj_shorten": 0.1, "mj_high_variation": 0.1, "mj_low_variation": 0.1, "mj_pan": 0.1, - "mj_inpaint_pre": 0, + "mj_inpaint": 0, + "mj_custom_zoom": 0, "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05, diff --git a/constant/midjourney.go b/constant/midjourney.go index 92e2f23..f4ae2e4 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -13,8 +13,9 @@ const ( MjActionVariation = "VARIATION" MjActionReRoll = "REROLL" MjActionInPaint = "INPAINT" - MjActionInPaintPre = "INPAINT_PRE" + MjActionModal = "MODAL" MjActionZoom = "ZOOM" + MjActionCustomZoom = "CUSTOM_ZOOM" MjActionShorten = "SHORTEN" MjActionHighVariation = "HIGH_VARIATION" MjActionLowVariation = "LOW_VARIATION" @@ -29,9 +30,10 @@ var MidjourneyModel2Action = map[string]string{ "mj_upscale": MjActionUpscale, "mj_variation": MjActionVariation, "mj_reroll": MjActionReRoll, + "mj_modal": MjActionModal, "mj_inpaint": MjActionInPaint, - "mj_inpaint_pre": MjActionInPaintPre, "mj_zoom": MjActionZoom, + "mj_custom_zoom": MjActionCustomZoom, "mj_shorten": MjActionShorten, "mj_high_variation": MjActionHighVariation, "mj_low_variation": MjActionLowVariation, diff --git a/controller/relay.go b/controller/relay.go index fa5493a..e31679d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -67,6 +67,8 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) + case relayconstant.RelayModeMidjourneyTaskImageSeed: + err = relay.RelayMidjourneyTaskImageSeed(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } diff --git a/dto/midjourney.go b/dto/midjourney.go index f81756c..d3c3583 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -28,6 +28,11 @@ type MidjourneyResponse struct { Result string `json:"result"` } +type MidjourneyResponseWithStatusCode struct { + StatusCode int `json:"statusCode"` + Response MidjourneyResponse +} + type MidjourneyDto struct { MjId string `json:"id"` Action string `json:"action"` diff --git a/middleware/distributor.go b/middleware/distributor.go index c1b3ccb..ed457a3 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -48,7 +48,8 @@ func Distribute() func(c *gin.Context) { relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || - relayMode == relayconstant.RelayModeMidjourneyNotify { + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { shouldSelectChannel = false } else { midjourneyRequest := dto.MidjourneyRequest{} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 9f13726..197efdc 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -17,6 +17,7 @@ const ( RelayModeMidjourneySimpleChange RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition RelayModeAudioSpeech RelayModeAudioTranscription @@ -77,6 +78,8 @@ func Path2RelayModeMidjourney(path string) int { relayMode = RelayModeMidjourneyChange } else if strings.HasSuffix(path, "/fetch") { relayMode = RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(path, "/image-seed") { + relayMode = RelayModeMidjourneyTaskImageSeed } else if strings.HasSuffix(path, "/list-by-condition") { relayMode = RelayModeMidjourneyTaskFetchByCondition } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 7342717..8eebaeb 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } +func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { + //taskId := c.Param("id") + //userId := c.GetInt("id") + //originTask := model.GetByMJId(userId, taskId) + //if originTask == nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + //} + //channel, err := model.GetChannelById(originTask.ChannelId, false) + //if err != nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + //} + //if channel.Status != common.ChannelStatusEnabled { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + //} + //c.Set("channel_id", originTask.ChannelId) + //requestURL := c.Request.URL.String() + //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) + //if err != nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed") + //} + log.Println("RelayMidjourneyTaskImageSeed") + return nil +} + func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = params.TaskId midjRequest.Action = params.Action } else if relayMode == relayconstant.RelayModeMidjourneyModal { - if midjRequest.MaskBase64 == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") - } + //if midjRequest.MaskBase64 == "" { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") + //} mjId = midjRequest.TaskId - midjRequest.Action = constant.MjActionInPaint + midjRequest.Action = constant.MjActionModal } originTask := model.GetByMJId(userId, mjId) @@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //} } - if midjRequest.Action == constant.MjActionInPaintPre { + if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom { consumeQuota = false } - // map model name - //modelMapping := c.GetString("model_mapping") - //isModelMapped := false - //if modelMapping != "" { - // modelMap := make(map[string]string) - // err := json.Unmarshal([]byte(modelMapping), &modelMap) - // if err != nil { - // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - // return &dto.MidjourneyResponse{ - // Code: 4, - // Description: "unmarshal_model_mapping_failed", - // } - // } - // if modelMap[imageModel] != "" { - // imageModel = modelMap[imageModel] - // isModelMapped = true - // } - //} - //baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - var requestBody io.Reader - //if isModelMapped { - // jsonStr, err := json.Marshal(midjRequest) - // if err != nil { - // return &dto.MidjourneyResponse{ - // Code: 4, - // Description: "marshal_text_request_failed", - // } - // } - // requestBody = bytes.NewBuffer(jsonStr) - //} else { - //} - requestBody = c.Request.Body - modelName := service.CoverActionToModelName(midjRequest.Action) modelPrice := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 @@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "create_request_failed", - } + return &midjResponseWithStatus.Response } - //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) - timeout := time.Second * 30 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) - // print request header - //log.Printf("request header: %s", req.Header) - //log.Printf("request body: %s", midjRequest.Prompt) - - defer cancel() - resp, err := service.GetHttpClient().Do(req) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed") - } - - err = req.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") - } - err = c.Request.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") - } - var midjResponse dto.MidjourneyResponse + midjResponse := &midjResponseWithStatus.Response defer func(ctx context.Context) { if consumeQuota { @@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed") - } - err = resp.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed") - } - if resp.StatusCode != 200 { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status") - } - err = json.Unmarshal(responseBody, &midjResponse) - log.Printf("responseBody: %s", string(responseBody)) - log.Printf("midjResponse: %v", midjResponse) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed") - } - // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md //1-提交成功 // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}} @@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //修改返回值 - if midjRequest.Action != constant.MjActionInPaintPre { + if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom { newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) responseBody = []byte(newBody) } @@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody = []byte(newBody) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, bodyReader) if err != nil { return &dto.MidjourneyResponse{ Code: 4, Description: "copy_response_body_failed", } } - err = resp.Body.Close() + err = bodyReader.Close() if err != nil { return &dto.MidjourneyResponse{ Code: 4, diff --git a/router/relay-router.go b/router/relay-router.go index f572d8f..3c6910a 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -57,6 +57,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) relayMjRouter.POST("/notify", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) + relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) } //relayMjRouter.Use() diff --git a/service/error.go b/service/error.go index 91c78c8..424be5d 100644 --- a/service/error.go +++ b/service/error.go @@ -18,6 +18,13 @@ func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { } } +func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode { + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: *MidjourneyErrorWrapper(code, desc), + } +} + // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() diff --git a/service/midjourney.go b/service/midjourney.go index c04c4d3..17e54e1 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -1,11 +1,18 @@ package service import ( + "context" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "log" + "net/http" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" "strconv" "strings" + "time" ) func CoverActionToModelName(mjAction string) string { @@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin case relayconstant.RelayModeMidjourneyChange: action = midjRequest.Action case relayconstant.RelayModeMidjourneyModal: - action = constant.MjActionInPaint + action = constant.MjActionModal case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj } else if strings.Contains(action, "reroll") { midjRequest.Action = constant.MjActionReRoll midjRequest.Index = 1 - } else if action == "Outpaint" || action == "CustomZoom" { + } else if action == "Outpaint" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 + } else if action == "CustomZoom" { + midjRequest.Action = constant.MjActionCustomZoom + midjRequest.Index = 1 } else if action == "Inpaint" { - midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Action = constant.MjActionInPaint midjRequest.Index = 1 } else { return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) @@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { changeParams.Index = index return changeParams } + +func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { + var nullBytes []byte + var requestBody io.Reader + requestBody = c.Request.Body + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) + defer cancel() + resp, err := GetHttpClient().Do(req) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err + } + statusCode := resp.StatusCode + //if statusCode != 200 { + // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil + //} + err = req.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + err = c.Request.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + var midjResponse dto.MidjourneyResponse + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err + } + err = resp.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err + } + + err = json.Unmarshal(responseBody, &midjResponse) + log.Printf("responseBody: %s", string(responseBody)) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } + //log.Printf("midjResponse: %v", midjResponse) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: midjResponse, + }, responseBody, nil +} diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 88da1cd..4843b2f 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -46,11 +46,13 @@ function renderType(type) { case 'REROLL': return 重绘; case 'INPAINT': - return 局部重绘; + return 局部重绘-提交; case 'ZOOM': return 变焦; - case 'INPAINT_PRE': - return 局部重绘-预处理; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; default: return 未知; } @@ -62,7 +64,7 @@ function renderCode(code) { case 1: return 已提交; case 21: - return 排队中; + return 等待中; case 22: return 重复提交; default: diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 757b56c..ccc18aa 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -109,8 +109,9 @@ const EditChannel = (props) => { 'mj_describe', 'mj_zoom', 'mj_shorten', - 'mj_inpaint_pre', + 'mj_modal', 'mj_inpaint', + 'mj_custom_zoom', 'mj_high_variation', 'mj_low_variation', 'mj_pan', From bc5a54df59ea464186390b95aa56cfbd5c605d55 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 16:59:46 +0800 Subject: [PATCH 13/16] feat: support image-seed (close #86) --- relay/relay-mj.go | 56 ++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 8eebaeb..6185d5b 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo } func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { - //taskId := c.Param("id") - //userId := c.GetInt("id") - //originTask := model.GetByMJId(userId, taskId) - //if originTask == nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") - //} - //channel, err := model.GetChannelById(originTask.ChannelId, false) - //if err != nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") - //} - //if channel.Status != common.ChannelStatusEnabled { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") - //} - //c.Set("channel_id", originTask.ChannelId) - //requestURL := c.Request.URL.String() - //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) - //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) - //if err != nil { - // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed") - //} - log.Println("RelayMidjourneyTaskImageSeed") + taskId := c.Param("id") + userId := c.GetInt("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + } + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + } + if channel.Status != common.ChannelStatusEnabled { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + } + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + requestURL := c.Request.URL.String() + fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) + midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil) + if err != nil { + return &midjResponseWithStatus.Response + } + midjResponse := &midjResponseWithStatus.Response + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } return nil } @@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 - channel, err := model.GetChannelById(originTask.ChannelId, false) + channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } @@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt From 9b5353a81a4f993f519641615c960d712b546d1c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 18:08:12 +0800 Subject: [PATCH 14/16] feat: support InsightFace (close #60) --- Midjourney.md | 10 ++- constant/midjourney.go | 4 +- controller/relay.go | 2 + dto/midjourney.go | 5 ++ relay/constant/relay_mode.go | 4 + relay/relay-mj.go | 129 +++++++++++++++++++++++++++++- router/relay-router.go | 1 + service/midjourney.go | 27 ++++++- web/src/components/MjLogsTable.js | 4 + 9 files changed, 173 insertions(+), 13 deletions(-) diff --git a/Midjourney.md b/Midjourney.md index d495e84..5733a11 100644 --- a/Midjourney.md +++ b/Midjourney.md @@ -19,8 +19,9 @@ - mj_zoom (比例变焦) - mj_shorten (提示词缩短) -- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加) -- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加) +- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加) +- mj_inpaint (局部重绘提交,必须和mj_modal一同添加) +- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加) - mj_high_variation (强变换) - mj_low_variation (弱变换) - mj_pan (平移) @@ -32,13 +33,14 @@ "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, - "mj_inpaint": 0.1, + "mj_modal": 0.1, "mj_zoom": 0.1, "mj_shorten": 0.1, "mj_high_variation": 0.1, "mj_low_variation": 0.1, "mj_pan": 0.1, - "mj_inpaint_pre": 0, + "mj_inpaint": 0, + "mj_custom_zoom": 0, "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05 diff --git a/constant/midjourney.go b/constant/midjourney.go index f4ae2e4..3d321ca 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -20,7 +20,7 @@ const ( MjActionHighVariation = "HIGH_VARIATION" MjActionLowVariation = "LOW_VARIATION" MjActionPan = "PAN" - SwapFace = "SWAP_FACE" + MjActionSwapFace = "SWAP_FACE" ) var MidjourneyModel2Action = map[string]string{ @@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{ "mj_high_variation": MjActionHighVariation, "mj_low_variation": MjActionLowVariation, "mj_pan": MjActionPan, - "swap_face": SwapFace, + "swap_face": MjActionSwapFace, } diff --git a/controller/relay.go b/controller/relay.go index e31679d..9f866b8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyTask(c, relayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: err = relay.RelayMidjourneyTaskImageSeed(c) + case relayconstant.RelayModeSwapFace: + err = relay.RelaySwapFace(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } diff --git a/dto/midjourney.go b/dto/midjourney.go index d3c3583..c675f7e 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -7,6 +7,11 @@ package dto // Content string `json:"content"` //} +type SwapFaceRequest struct { + SourceBase64 string `json:"sourceBase64"` + TargetBase64 string `json:"targetBase64"` +} + type MidjourneyRequest struct { Prompt string `json:"prompt"` CustomId string `json:"customId"` diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 197efdc..1790c57 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -25,6 +25,7 @@ const ( RelayModeMidjourneyAction RelayModeMidjourneyModal RelayModeMidjourneyShorten + RelayModeSwapFace ) func Path2RelayMode(path string) int { @@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int { } else if strings.HasPrefix(path, "/mj/submit/shorten") { // midjourney plus relayMode = RelayModeMidjourneyShorten + } else if strings.HasPrefix(path, "/mj/insight-face/swap") { + // midjourney plus + relayMode = RelayModeSwapFace } else if strings.HasPrefix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine } else if strings.HasPrefix(path, "/mj/submit/blend") { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 6185d5b..01ae0c8 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } +func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { + startTime := time.Now().UnixNano() / int64(time.Millisecond) + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + channelId := c.GetInt("channel_id") + var swapFaceRequest dto.SwapFaceRequest + err := common.UnmarshalBodyReusable(c, &swapFaceRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") + } + modelName := service.CoverActionToModelName(constant.MjActionSwapFace) + modelPrice := common.GetModelPrice(modelName, true) + // 如果没有配置价格,则使用默认价格 + if modelPrice == -1 { + defaultPrice, ok := common.DefaultModelPrice[modelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + groupRatio := common.GetGroupRatio(group) + ratio := modelPrice * groupRatio + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return &dto.MidjourneyResponse{ + Code: 4, + Description: err.Error(), + } + } + quota := int(ratio * common.QuotaPerUnit) + + if userQuota-quota < 0 { + return &dto.MidjourneyResponse{ + Code: 4, + Description: "quota_not_enough", + } + } + requestURL := c.Request.URL.String() + baseURL := c.GetString("base_url") + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL) + if err != nil { + return &mjResp.Response + } + defer func(ctx context.Context) { + if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }(c.Request.Context()) + midjResponse := &mjResp.Response + midjourneyTask := &model.Midjourney{ + UserId: userId, + Code: midjResponse.Code, + Action: constant.MjActionSwapFace, + MjId: midjResponse.Result, + Prompt: "swap_face", + PromptEn: "", + Description: midjResponse.Description, + State: "", + SubmitTime: startTime, + StartTime: time.Now().UnixNano() / int64(time.Millisecond), + FinishTime: 0, + ImageUrl: "", + Status: "", + Progress: "0%", + FailReason: "", + ChannelId: c.GetInt("channel_id"), + Quota: quota, + } + err = midjourneyTask.Insert() + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed") + } + c.Writer.WriteHeader(mjResp.StatusCode) + respBody, err := json.Marshal(midjResponse) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil +} + func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { taskId := c.Param("id") userId := c.GetInt("id") @@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { requestURL := c.Request.URL.String() fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) - midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil) + midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) if err != nil { return &midjResponseWithStatus.Response } + //defer func(ctx context.Context) { + // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) + // if err != nil { + // common.SysError("error consuming token remain quota: " + err.Error()) + // } + // err = model.CacheUpdateUserQuota(userId) + // if err != nil { + // common.SysError("error update user quota cache: " + err.Error()) + // } + // if quota != 0 { + // tokenName := c.GetString("token_name") + // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action) + // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false) + // model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + // channelId := c.GetInt("channel_id") + // model.UpdateChannelUsedQuota(channelId, quota) + // } + //}(c.Request.Context()) midjResponse := &midjResponseWithStatus.Response c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) respBody, err := json.Marshal(midjResponse) @@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) if err != nil { return &midjResponseWithStatus.Response } midjResponse := &midjResponseWithStatus.Response defer func(ctx context.Context) { - if consumeQuota { + if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) diff --git a/router/relay-router.go b/router/relay-router.go index 3c6910a..4addee0 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) + relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) } //relayMjRouter.Use() } diff --git a/service/midjourney.go b/service/midjourney.go index 17e54e1..7c47cd6 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -17,6 +17,9 @@ import ( func CoverActionToModelName(mjAction string) string { modelName := "mj_" + strings.ToLower(mjAction) + if mjAction == constant.MjActionSwapFace { + modelName = "swap_face" + } return modelName } @@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin action = midjRequest.Action case relayconstant.RelayModeMidjourneyModal: action = constant.MjActionModal + case relayconstant.RelayModeSwapFace: + action = constant.MjActionSwapFace case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { return changeParams } -func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { +func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { var nullBytes []byte - var requestBody io.Reader - requestBody = c.Request.Body - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + //var requestBody io.Reader + //requestBody = c.Request.Body + // read request body to json, delete accountFilter and notifyHook + var mapResult map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + delete(mapResult, "accountFilter") + delete(mapResult, "notifyHook") + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + // make new request with mapResult + reqBody, err := json.Marshal(mapResult) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody))) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err } diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 4843b2f..90c55f1 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -53,6 +53,8 @@ function renderType(type) { return 自定义变焦-提交; case 'MODAL': return 窗口处理; + case 'SWAP_FACE': + return 换脸; default: return 未知; } @@ -67,6 +69,8 @@ function renderCode(code) { return 等待中; case 22: return 重复提交; + case 0: + return 未提交; default: return 未知; } From 2786a6b53931230d128502a6d2b01d18c798e13d Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 18:10:09 +0800 Subject: [PATCH 15/16] Update README.md --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f1d18da..ce3f27f 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ 此分叉版本的主要变更如下: 1. 全新的UI界面(部分界面还待更新) -2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持 +2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持 + [x] /mj/submit/imagine + [x] /mj/submit/change + [x] /mj/submit/blend @@ -26,6 +26,11 @@ + [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**) + [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址) + [x] /task/list-by-condition + + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同) + + [x] /mj/submit/modal + + [x] /mj/submit/shorten + + [x] /mj/task/{id}/image-seed + + [x] /mj/insight-face/swap (InsightFace) 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: + [x] 易支付 4. 支持用key查询使用额度: @@ -49,6 +54,7 @@ 2. 智谱glm-4v,glm-4v识图 3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229) 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 +5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 From 84e0544604cc3a23bba80bd834b5314710fab1ee Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 14 Mar 2024 18:19:22 +0800 Subject: [PATCH 16/16] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E6=97=B6=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/relay-mj.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 01ae0c8..35353b4 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -183,7 +183,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { requestURL := c.Request.URL.String() baseURL := c.GetString("base_url") fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL) + mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) if err != nil { return &mjResp.Response } @@ -213,7 +213,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Code: midjResponse.Code, Action: constant.MjActionSwapFace, MjId: midjResponse.Result, - Prompt: "swap_face", + Prompt: "InsightFace", PromptEn: "", Description: midjResponse.Description, State: "", @@ -495,7 +495,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL) if err != nil { return &midjResponseWithStatus.Response }