package service

import (
	"one-api/constant"
	"one-api/dto"
	relayconstant "one-api/relay/constant"
	"strconv"
	"strings"
)

// CoverActionToModelName returns the model name for a Midjourney action:
// the lower-cased action prefixed with "mj_".
func CoverActionToModelName(mjAction string) string {
	return "mj_" + strings.ToLower(mjAction)
}

// GetMjRequestModel resolves the model name for a Midjourney request.
//
// It returns (modelName, nil, true) when an action could be determined,
// ("", nil, true) for the task-fetch and notify relay modes (which carry no
// action), and ("", errResponse, false) when the request is invalid or the
// relay mode is not recognised.
//
// For relayconstant.RelayModeMidjourneyAction the request is modified in
// place by CoverPlusActionToNormalAction.
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
	var action string
	switch relayMode {
	case relayconstant.RelayModeMidjourneyAction:
		// plus request: derive Action/Index from the custom ID first.
		if mjErr := CoverPlusActionToNormalAction(midjRequest); mjErr != nil {
			return "", mjErr, false
		}
		action = midjRequest.Action
	case relayconstant.RelayModeMidjourneyImagine:
		action = constant.MjActionImagine
	case relayconstant.RelayModeMidjourneyDescribe:
		action = constant.MjActionDescribe
	case relayconstant.RelayModeMidjourneyBlend:
		action = constant.MjActionBlend
	case relayconstant.RelayModeMidjourneyShorten:
		action = constant.MjActionShorten
	case relayconstant.RelayModeMidjourneyChange:
		action = midjRequest.Action
	case relayconstant.RelayModeMidjourneyModal:
		action = constant.MjActionInPaint
	case relayconstant.RelayModeMidjourneySimpleChange:
		params := ConvertSimpleChangeParams(midjRequest.Content)
		if params == nil {
			return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
		}
		action = params.Action
	case relayconstant.RelayModeMidjourneyTaskFetch,
		relayconstant.RelayModeMidjourneyTaskFetchByCondition,
		relayconstant.RelayModeMidjourneyNotify:
		// These modes have no action, hence no model name.
		return "", nil, true
	default:
		return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
	}
	return CoverActionToModelName(action), nil, true
}

// CoverPlusActionToNormalAction derives midjRequest.Action and
// midjRequest.Index from midjRequest.CustomId, modifying the request in
// place.
//
// The custom ID is split on "::". The action is the third segment when the
// second segment is "JOB", otherwise the second segment. For "upsample" and
// plain "variation" actions the index is parsed from the fourth segment.
//
// It returns nil on success, or an error response when the custom ID is
// empty ("custom_id_is_required"), the action is missing or not recognised
// ("unknown_action"), or the index is missing or not an integer
// ("index_parse_failed"). Malformed IDs with too few segments are reported
// through these errors rather than causing an index-out-of-range panic.
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
	// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
	customID := midjRequest.CustomId
	if customID == "" {
		return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
	}
	splits := strings.Split(customID, "::")

	// segment returns splits[i], or "" when the custom ID has too few
	// segments; the ID is client-supplied, so its length cannot be trusted.
	segment := func(i int) string {
		if i < len(splits) {
			return splits[i]
		}
		return ""
	}

	action := segment(1)
	if action == "JOB" {
		action = segment(2)
	}
	if action == "" {
		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
	}

	switch {
	case strings.Contains(action, "upsample"):
		// A missing segment yields "", which Atoi rejects.
		index, err := strconv.Atoi(segment(3))
		if err != nil {
			return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
		}
		midjRequest.Index = index
		midjRequest.Action = constant.MjActionUpscale
	case strings.Contains(action, "variation"):
		midjRequest.Index = 1
		switch action {
		case "variation":
			index, err := strconv.Atoi(segment(3))
			if err != nil {
				return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
			}
			midjRequest.Index = index
			midjRequest.Action = constant.MjActionVariation
		case "low_variation":
			midjRequest.Action = constant.MjActionLowVariation
		case "high_variation":
			midjRequest.Action = constant.MjActionHighVariation
		}
		// NOTE(review): any other action containing "variation" leaves
		// midjRequest.Action untouched and still returns nil; confirm
		// whether that should be rejected as "unknown_action".
	case strings.Contains(action, "pan"):
		midjRequest.Action = constant.MjActionPan
		midjRequest.Index = 1
	case action == "Outpaint" || action == "CustomZoom":
		midjRequest.Action = constant.MjActionZoom
		midjRequest.Index = 1
	case action == "Inpaint":
		midjRequest.Action = constant.MjActionInPaintPre
		midjRequest.Index = 1
	default:
		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
	}
	return nil
}

// ConvertSimpleChangeParams parses a simple-change command of the form
// "<taskId> <action>" (exactly one space) into a request.
//
// The action is matched case-insensitively:
//   - "u<d>" -> Action "UPSCALE",   Index d
//   - "v<d>" -> Action "VARIATION", Index d
//   - "r"    -> Action "REROLL" (no index)
//
// where d is the single character following the letter and must be a digit
// from 1 to 4. Any characters after that digit are ignored. It returns nil
// for any other input, including an empty action or a bare "u"/"v" without
// an index.
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
	split := strings.Split(content, " ")
	if len(split) != 2 {
		return nil
	}
	action := strings.ToLower(split[1])
	if action == "" {
		// e.g. content "taskId " — nothing to dispatch on.
		return nil
	}

	changeParams := &dto.MidjourneyRequest{}
	changeParams.TaskId = split[0]

	switch {
	case action[0] == 'u':
		changeParams.Action = "UPSCALE"
	case action[0] == 'v':
		changeParams.Action = "VARIATION"
	case action == "r":
		changeParams.Action = "REROLL"
		return changeParams
	default:
		return nil
	}

	// Upscale and variation require an index character after the letter.
	if len(action) < 2 {
		return nil
	}
	index, err := strconv.Atoi(action[1:2])
	if err != nil || index < 1 || index > 4 {
		return nil
	}
	changeParams.Index = index
	return changeParams
}