diff --git a/common/constants.go b/common/constants.go index 98fa67a..cbb7861 100644 --- a/common/constants.go +++ b/common/constants.go @@ -189,7 +189,7 @@ const ( ChannelTypeMidjourney = 2 ChannelTypeAzure = 3 ChannelTypeOllama = 4 - ChannelTypeOpenAISB = 5 + ChannelTypeMidjourneyPlus = 5 ChannelTypeOpenAIMax = 6 ChannelTypeOhMyGPT = 7 ChannelTypeCustom = 8 diff --git a/constant/midjourney.go b/constant/midjourney.go new file mode 100644 index 0000000..dbcc5c8 --- /dev/null +++ b/constant/midjourney.go @@ -0,0 +1,16 @@ +package constant + +const ( + MjErrorUnknown = 5 + MjRequestError = 4 +) + +const ( + MjActionImagine = "IMAGINE" + MjActionDescribe = "DESCRIBE" + MjActionBlend = "BLEND" + MjActionUpscale = "UPSCALE" + MjActionVariation = "VARIATION" + MjActionInPaint = "INPAINT" + MjActionInPaintPre = "INPAINT_PRE" +) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 4bcd4d4..96f82ee 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return 0, errors.New("尚未实现") case common.ChannelTypeCustom: baseURL = channel.GetBaseURL() - case common.ChannelTypeOpenAISB: - return updateChannelOpenAISBBalance(channel) + //case common.ChannelTypeOpenAISB: + // return updateChannelOpenAISBBalance(channel) case common.ChannelTypeAIProxy: return updateChannelAIProxyBalance(channel) case common.ChannelTypeAPI2GPT: diff --git a/controller/midjourney.go b/controller/midjourney.go index 1a42270..cac253c 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -10,8 +10,8 @@ import ( "log" "net/http" "one-api/common" + "one-api/dto" "one-api/model" - relay2 "one-api/relay" "one-api/service" "strconv" "strings" @@ -75,11 +75,11 @@ import ( responseBody, err := io.ReadAll(resp.Body) resp.Body.Close() log.Printf("responseBody: %s", string(responseBody)) - var responseItem Midjourney + var responseItem MidjourneyDto // err = json.NewDecoder(resp.Body).Decode(&responseItem) err = json.Unmarshal(responseBody, &responseItem) if err != nil { - if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") { + if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") { var responseWithoutStatus MidjourneyWithoutStatus var responseStatus MidjourneyStatus err1 := json.Unmarshal(responseBody, &responseWithoutStatus) @@ -228,12 +228,16 @@ func UpdateMidjourneyTaskBulk() { common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + continue + } responseBody, err := io.ReadAll(resp.Body) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } - var responseItems []relay2.Midjourney + var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) @@ -259,6 +263,10 @@ func UpdateMidjourneyTaskBulk() { task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason + if responseItem.Buttons != nil { + buttonStr, _ := json.Marshal(responseItem.Buttons) + task.Buttons = string(buttonStr) + } if task.Progress != "100%" && responseItem.FailReason != "" { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" @@ -286,7 +294,7 @@ func UpdateMidjourneyTaskBulk() { } } -func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool { +func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool { if oldTask.Code != 1 { return true } diff --git a/controller/relay.go b/controller/relay.go index 911a7c5..a42db2e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,13 @@ func Relay(c *gin.Context) { func RelayMidjourney(c *gin.Context) { relayMode := relayconstant.RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { + if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyAction + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") { + // midjourney plus + relayMode = relayconstant.RelayModeMidjourneyModal + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { relayMode = relayconstant.RelayModeMidjourneyImagine } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { relayMode = relayconstant.RelayModeMidjourneyBlend @@ -86,35 +92,24 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) + //case relayconstant.RelayModeMidjourneyModal: + // err = relay.RelayMidjournneyModal(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } //err = relayMidjourneySubmit(c, relayMode) log.Println(err) if err != nil { - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes - } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" - } - c.JSON(429, gin.H{ - "error": fmt.Sprintf("%s %s", err.Description, err.Result), - "type": "upstream_error", - }) + if err.Code == 30 { + err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" } + c.JSON(429, gin.H{ + "error": fmt.Sprintf("%s %s", err.Description, err.Result), + "type": "upstream_error", + "code": err.Code, + }) channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) - //if shouldDisableChannel(&err.Error) { - // channelId := c.GetInt("channel_id") - // channelName := c.GetString("channel_name") - // disableChannel(channelId, channelName, err.Result) - //};'''''''''''''''''''''''''''''''' } } diff --git a/dto/midjourney.go b/dto/midjourney.go index 4c67909..a16a65e 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -2,6 +2,8 @@ package dto type MidjourneyRequest struct { Prompt string `json:"prompt"` + CustomId string `json:"customId"` + BotType string `json:"botType"` NotifyHook string `json:"notifyHook"` Action string `json:"action"` Index int `json:"index"` @@ -9,6 +11,7 @@ type MidjourneyRequest struct { TaskId string `json:"taskId"` Base64Array []string `json:"base64Array"` Content string `json:"content"` + MaskBase64 string `json:"maskBase64"` } type MidjourneyResponse struct { @@ -17,3 +20,52 @@ type MidjourneyResponse struct { Properties interface{} `json:"properties"` Result string `json:"result"` } + +type MidjourneyDto struct { + MjId string `json:"id"` + Action string `json:"action"` + CustomId string `json:"customId"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + ImageUrl string `json:"imageUrl"` + Status string `json:"status"` + Progress string `json:"progress"` + FailReason string `json:"failReason"` + Buttons any `json:"buttons"` + MaskBase64 string `json:"maskBase64"` +} + +type MidjourneyStatus struct { + Status int `json:"status"` +} +type MidjourneyWithoutStatus struct { + Id int `json:"id"` + Code int `json:"code"` + UserId int `json:"user_id" gorm:"index"` + Action string `json:"action"` + MjId string `json:"mj_id" gorm:"index"` + Prompt string `json:"prompt"` + PromptEn string `json:"prompt_en"` + Description string `json:"description"` + State string `json:"state"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + ImageUrl string `json:"image_url"` + Progress string `json:"progress"` + FailReason string `json:"fail_reason"` + ChannelId int `json:"channel_id"` +} + +type ActionButton struct { + CustomId any `json:"customId"` + Emoji any `json:"emoji"` + Label any `json:"label"` + Type any `json:"type"` + Style any `json:"style"` +} diff --git a/middleware/auth.go b/middleware/auth.go index ef774f6..a8dac30 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -125,12 +125,6 @@ func TokenAuth() func(c *gin.Context) { } else { c.Set("token_model_limit_enabled", false) } - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/model/midjourney.go b/model/midjourney.go index 0ef2e55..f20ab32 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -19,6 +19,7 @@ type Midjourney struct { FailReason string `json:"fail_reason"` ChannelId int `json:"channel_id"` Quota int `json:"quota"` + Buttons string `json:"buttons"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index beea7dc..c49caae 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -21,6 +21,8 @@ const ( RelayModeAudioSpeech RelayModeAudioTranscription RelayModeAudioTranslation + RelayModeMidjourneyAction + RelayModeMidjourneyModal ) func Path2RelayMode(path string) int { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index b2b9926..f667cd1 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relayconstant "one-api/relay/constant" @@ -20,51 +21,15 @@ import ( "github.com/gin-gonic/gin" ) -type Midjourney struct { - MjId string `json:"id"` - Action string `json:"action"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - ImageUrl string `json:"imageUrl"` - Status string `json:"status"` - Progress string `json:"progress"` - FailReason string `json:"failReason"` -} - -type MidjourneyStatus struct { - Status int `json:"status"` -} -type MidjourneyWithoutStatus struct { - Id int `json:"id"` - Code int `json:"code"` - UserId int `json:"user_id" gorm:"index"` - Action string `json:"action"` - MjId string `json:"mj_id" gorm:"index"` - Prompt string `json:"prompt"` - PromptEn string `json:"prompt_en"` - Description string `json:"description"` - State string `json:"state"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - ImageUrl string `json:"image_url"` - Progress string `json:"progress"` - FailReason string `json:"fail_reason"` - ChannelId int `json:"channel_id"` -} - var DefaultModelPrice = map[string]float64{ - "mj_imagine": 0.1, - "mj_variation": 0.1, - "mj_reroll": 0.1, - "mj_blend": 0.1, - "mj_describe": 0.05, - "mj_upscale": 0.05, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_inpaint": 0.1, + "mj_inpaint_pre": 0, + "mj_describe": 0.05, + "mj_upscale": 0.05, } func RelayMidjourneyImage(c *gin.Context) { @@ -108,7 +73,7 @@ func RelayMidjourneyImage(c *gin.Context) { } func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { - var midjRequest Midjourney + var midjRequest dto.MidjourneyDto err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { return &dto.MidjourneyResponse{ @@ -147,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { return nil } -func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) { +func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -167,9 +132,41 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Action = originTask.Action midjourneyTask.Description = originTask.Description midjourneyTask.Prompt = originTask.Prompt + if originTask.Buttons != "" { + var buttons []dto.ActionButton + err := json.Unmarshal([]byte(originTask.Buttons), &buttons) + if err == nil { + midjourneyTask.Buttons = buttons + } + } return } +func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse { + userId := c.GetInt("id") + var midjRequest dto.MidjourneyRequest + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + originTask := model.GetByMJId(userId, midjRequest.TaskId) + if originTask == nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + } + + respBody, err := json.Marshal(midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") + } + c.Writer.Header().Set("Content-Type", "application/json") + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") + } + return nil + +} + func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -184,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "task_no_found", } } - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := getMidjourneyTaskDto(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { return &dto.MidjourneyResponse{ @@ -203,16 +200,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse Description: "do_request_failed", } } - var tasks []Midjourney + var tasks []dto.MidjourneyDto if len(condition.IDs) != 0 { originTasks := model.GetByMJIds(userId, condition.IDs) for _, originTask := range originTasks { - midjourneyTask := getMidjourneyTaskModel(c, originTask) + midjourneyTask := getMidjourneyTaskDto(c, originTask) tasks = append(tasks, midjourneyTask) } } if tasks == nil { - tasks = make([]Midjourney, 0) + tasks = make([]dto.MidjourneyDto, 0) } respBody, err = json.Marshal(tasks) if err != nil { @@ -235,44 +232,32 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -const ( - // type 1 根据 mode 价格不同 - MJSubmitActionImagine = "IMAGINE" - MJSubmitActionVariation = "VARIATION" //变换 - MJSubmitActionBlend = "BLEND" //混图 - - MJSubmitActionReroll = "REROLL" //重新生成 - // type 2 固定价格 - MJSubmitActionDescribe = "DESCRIBE" - MJSubmitActionUpscale = "UPSCALE" // 放大 -) - func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { imageModel := "midjourney" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") channelId := c.GetInt("channel_id") + consumeQuota := true var midjRequest dto.MidjourneyRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &midjRequest) - if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "bind_request_body_failed", - } + err := common.UnmarshalBodyReusable(c, &midjRequest) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") + } + + if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + mjErr := coverPlusActionToNormalAction(&midjRequest) + if mjErr != nil { + return mjErr } + relayMode = relayconstant.RelayModeMidjourneyChange } if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "prompt_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = "IMAGINE" } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 @@ -283,71 +268,58 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "taskId_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "action_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") } else if midjRequest.Index == 0 { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "index_can_only_be_1_2_3_4", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required") } //action = midjRequest.Action mjId = midjRequest.TaskId } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_is_required", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } params := convertSimpleChangeParams(midjRequest.Content) if params == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "content_parse_failed", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed") } mjId = params.ID midjRequest.Action = params.Action + } else if relayMode == relayconstant.RelayModeMidjourneyModal { + if midjRequest.MaskBase64 == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") + } + mjId = midjRequest.TaskId + midjRequest.Action = "INPAINT" } originTask := model.GetByMJId(userId, mjId) if originTask == nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_no_found", - } - } else if originTask.Action == "UPSCALE" { - //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). - return &dto.MidjourneyResponse{ - Code: 4, - Description: "upscale_task_can_not_be_change", - } - } else if originTask.Status != "SUCCESS" { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "task_status_is_not_success", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") + } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 channel, err := model.GetChannelById(originTask.ChannelId, false) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "channel_not_found", - } + return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") } c.Set("base_url", channel.GetBaseURL()) c.Set("channel_id", originTask.ChannelId) - log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) + log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt + + if channelType == common.ChannelTypeMidjourneyPlus { + // plus + } else { + // 普通版渠道 + + } + } + + if midjRequest.Action == constant.MjActionInPaintPre { + consumeQuota = false } // map model name @@ -379,7 +351,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify" fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - log.Printf("fullRequestURL: %s", fullRequestURL) var requestBody io.Reader if isModelMapped { @@ -394,6 +365,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } else { requestBody = c.Request.Body } + mjAction := "mj_" + strings.ToLower(midjRequest.Action) modelPrice := common.GetModelPrice(mjAction, true) // 如果没有配置价格,则使用默认价格 @@ -489,9 +461,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } }(c.Request.Context()) - //if consumeQuota { - // - //} responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -651,3 +620,43 @@ func convertSimpleChangeParams(content string) *taskChangeParams { changeParams.Index = index return changeParams } + +func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { + // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" + customId := midjRequest.CustomId + if customId == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") + } + splits := strings.Split(customId, "::") + var action string + if splits[1] == "JOB" { + action = splits[2] + } else { + action = splits[1] + } + + if action == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + if strings.Contains(action, "upsample") { + index, err := strconv.Atoi(splits[3]) + if err != nil { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") + } + midjRequest.Index = index + midjRequest.Action = constant.MjActionUpscale + } else if strings.Contains(action, "variation") { + midjRequest.Action = constant.MjActionVariation + } else if strings.Contains(action, "pan") { + midjRequest.Action = constant.MjActionVariation + midjRequest.Index = 1 + } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") { + midjRequest.Action = constant.MjActionInPaintPre + } else if action == "Inpaint" { + midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Index = 1 + } else { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") + } + return nil +} diff --git a/router/relay-router.go b/router/relay-router.go index 6a30a5a..68b762b 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -47,6 +47,8 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { + relayMjRouter.POST("/submit/action", controller.RelayMidjourney) + relayMjRouter.POST("/submit/modal", controller.RelayMidjourney) relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) diff --git a/service/error.go b/service/error.go index 303bcf7..91c78c8 100644 --- a/service/error.go +++ b/service/error.go @@ -11,6 +11,13 @@ import ( "strings" ) +func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { + return &dto.MidjourneyResponse{ + Code: code, + Description: desc, + } +} + // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 1f71208..4f17c14 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -35,6 +35,10 @@ function renderType(type) { return 图生文; case 'BLEAND': return 图混合; + case 'INPAINT': + return 局部重绘; + case 'INPAINT_PRE': + return 局部重绘-预处理; default: return 未知; } @@ -68,6 +72,8 @@ function renderStatus(type) { return 执行中; case 'FAILURE': return 失败; + case 'MODAL': + return 窗口等待; default: return 未知; } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index a641a02..bb18d0d 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,6 +1,7 @@ export const CHANNEL_OPTIONS = [ {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'}, {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'}, + {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'}, {key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'}, {key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'}, {key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},