package service import ( "context" "encoding/json" "github.com/gin-gonic/gin" "io" "log" "net/http" "one-api/common" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" "strconv" "strings" "time" ) func CoverActionToModelName(mjAction string) string { modelName := "mj_" + strings.ToLower(mjAction) if mjAction == constant.MjActionSwapFace { modelName = "swap_face" } return modelName } func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) { action := "" if relayMode == relayconstant.RelayModeMidjourneyAction { // plus request err := CoverPlusActionToNormalAction(midjRequest) if err != nil { return "", err, false } action = midjRequest.Action } else { switch relayMode { case relayconstant.RelayModeMidjourneyImagine: action = constant.MjActionImagine case relayconstant.RelayModeMidjourneyDescribe: action = constant.MjActionDescribe case relayconstant.RelayModeMidjourneyBlend: action = constant.MjActionBlend case relayconstant.RelayModeMidjourneyShorten: action = constant.MjActionShorten case relayconstant.RelayModeMidjourneyChange: action = midjRequest.Action case relayconstant.RelayModeMidjourneyModal: action = constant.MjActionModal case relayconstant.RelayModeSwapFace: action = constant.MjActionSwapFace case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false } action = params.Action case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify: return "", nil, true default: return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false } } modelName := CoverActionToModelName(action) return modelName, nil, true } func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse { // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011" customId := midjRequest.CustomId if customId == "" { return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required") } splits := strings.Split(customId, "::") var action string if splits[1] == "JOB" { action = splits[2] } else { action = splits[1] } if action == "" { return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action") } if strings.Contains(action, "upsample") { index, err := strconv.Atoi(splits[3]) if err != nil { return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") } midjRequest.Index = index midjRequest.Action = constant.MjActionUpscale } else if strings.Contains(action, "variation") { midjRequest.Index = 1 if action == "variation" { index, err := strconv.Atoi(splits[3]) if err != nil { return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed") } midjRequest.Index = index midjRequest.Action = constant.MjActionVariation } else if action == "low_variation" { midjRequest.Action = constant.MjActionLowVariation } else if action == "high_variation" { midjRequest.Action = constant.MjActionHighVariation } } else if strings.Contains(action, "pan") { midjRequest.Action = constant.MjActionPan midjRequest.Index = 1 } else if strings.Contains(action, "reroll") { midjRequest.Action = constant.MjActionReRoll midjRequest.Index = 1 } else if action == "Outpaint" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 } else if action == "CustomZoom" { midjRequest.Action = constant.MjActionCustomZoom midjRequest.Index = 1 } else if action == "Inpaint" { midjRequest.Action = constant.MjActionInPaint midjRequest.Index = 1 } else { return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) } return nil } func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { split := strings.Split(content, " ") if len(split) != 2 { return nil } action := strings.ToLower(split[1]) changeParams := &dto.MidjourneyRequest{} changeParams.TaskId = split[0] if action[0] == 'u' { changeParams.Action = "UPSCALE" } else if action[0] == 'v' { changeParams.Action = "VARIATION" } else if action == "r" { changeParams.Action = "REROLL" return changeParams } else { return nil } index, err := strconv.Atoi(action[1:2]) if err != nil || index < 1 || index > 4 { return nil } changeParams.Index = index return changeParams } func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { var nullBytes []byte //var requestBody io.Reader //requestBody = c.Request.Body // read request body to json, delete accountFilter and notifyHook var mapResult map[string]interface{} // if get request, no need to read request body if c.Request.Method != "GET" { err := json.NewDecoder(c.Request.Body).Decode(&mapResult) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err } delete(mapResult, "accountFilter") if !constant.MjNotifyEnabled { delete(mapResult, "notifyHook") } //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) // make new request with mapResult } if constant.MjModeClearEnabled { if prompt, ok := mapResult["prompt"].(string); ok { prompt = strings.Replace(prompt, "--fast", "", -1) prompt = strings.Replace(prompt, "--relax", "", -1) prompt = strings.Replace(prompt, "--turbo", "", -1) mapResult["prompt"] = prompt } } reqBody, err := json.Marshal(mapResult) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err } req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody))) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err } ctx, cancel := context.WithTimeout(context.Background(), timeout) // 使用带有超时的 context 创建新的请求 req = req.WithContext(ctx) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) auth := c.Request.Header.Get("Authorization") if auth != "" { auth = strings.TrimPrefix(auth, "Bearer ") req.Header.Set("mj-api-secret", auth) } defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { common.SysError("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode //if statusCode != 200 { // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil //} err = req.Body.Close() if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err } err = c.Request.Body.Close() if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err } var midjResponse dto.MidjourneyResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } err = resp.Body.Close() if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err } respStr := string(responseBody) log.Printf("responseBody: %s", respStr) if respStr == "" { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil } else { err = json.Unmarshal(responseBody, &midjResponse) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err } } //log.Printf("midjResponse: %v", midjResponse) //for k, v := range resp.Header { // c.Writer.Header().Set(k, v[0]) //} return &dto.MidjourneyResponseWithStatusCode{ StatusCode: statusCode, Response: midjResponse, }, responseBody, nil }