diff --git a/common/model-ratio.go b/common/model-ratio.go index 153b748..3231d95 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -100,13 +100,14 @@ var DefaultModelPrice = map[string]float64{ "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, - "mj_inpaint": 0.1, + "mj_modal": 0.1, "mj_zoom": 0.1, "mj_shorten": 0.1, "mj_high_variation": 0.1, "mj_low_variation": 0.1, "mj_pan": 0.1, - "mj_inpaint_pre": 0, + "mj_inpaint": 0, + "mj_custom_zoom": 0, "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05, diff --git a/constant/midjourney.go b/constant/midjourney.go index 92e2f23..f4ae2e4 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -13,8 +13,9 @@ const ( MjActionVariation = "VARIATION" MjActionReRoll = "REROLL" MjActionInPaint = "INPAINT" - MjActionInPaintPre = "INPAINT_PRE" + MjActionModal = "MODAL" MjActionZoom = "ZOOM" + MjActionCustomZoom = "CUSTOM_ZOOM" MjActionShorten = "SHORTEN" MjActionHighVariation = "HIGH_VARIATION" MjActionLowVariation = "LOW_VARIATION" @@ -29,9 +30,10 @@ var MidjourneyModel2Action = map[string]string{ "mj_upscale": MjActionUpscale, "mj_variation": MjActionVariation, "mj_reroll": MjActionReRoll, + "mj_modal": MjActionModal, "mj_inpaint": MjActionInPaint, - "mj_inpaint_pre": MjActionInPaintPre, "mj_zoom": MjActionZoom, + "mj_custom_zoom": MjActionCustomZoom, "mj_shorten": MjActionShorten, "mj_high_variation": MjActionHighVariation, "mj_low_variation": MjActionLowVariation, diff --git a/controller/relay.go b/controller/relay.go index fa5493a..e31679d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -67,6 +67,8 @@ func RelayMidjourney(c *gin.Context) { err = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: err = relay.RelayMidjourneyTask(c, relayMode) + case relayconstant.RelayModeMidjourneyTaskImageSeed: + err = relay.RelayMidjourneyTaskImageSeed(c) default: err = relay.RelayMidjourneySubmit(c, relayMode) } diff --git a/dto/midjourney.go b/dto/midjourney.go index f81756c..d3c3583 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -28,6 +28,11 @@ type MidjourneyResponse struct { Result string `json:"result"` } +type MidjourneyResponseWithStatusCode struct { + StatusCode int `json:"statusCode"` + Response MidjourneyResponse +} + type MidjourneyDto struct { MjId string `json:"id"` Action string `json:"action"` diff --git a/middleware/distributor.go b/middleware/distributor.go index c1b3ccb..ed457a3 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -48,7 +48,8 @@ func Distribute() func(c *gin.Context) { relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || - relayMode == relayconstant.RelayModeMidjourneyNotify { + relayMode == relayconstant.RelayModeMidjourneyNotify || + relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { shouldSelectChannel = false } else { midjourneyRequest := dto.MidjourneyRequest{} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 9f13726..197efdc 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -17,6 +17,7 @@ const ( RelayModeMidjourneySimpleChange RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition RelayModeAudioSpeech RelayModeAudioTranscription @@ -77,6 +78,8 @@ func Path2RelayModeMidjourney(path string) int { relayMode = RelayModeMidjourneyChange } else if strings.HasSuffix(path, "/fetch") { relayMode = RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(path, "/image-seed") { + relayMode = RelayModeMidjourneyTaskImageSeed } else if strings.HasSuffix(path, "/list-by-condition") { relayMode = RelayModeMidjourneyTaskFetchByCondition } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 7342717..8eebaeb 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } +func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { + //taskId := c.Param("id") + //userId := c.GetInt("id") + //originTask := model.GetByMJId(userId, taskId) + //if originTask == nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found") + //} + //channel, err := model.GetChannelById(originTask.ChannelId, false) + //if err != nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed") + //} + //if channel.Status != common.ChannelStatusEnabled { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用") + //} + //c.Set("channel_id", originTask.ChannelId) + //requestURL := c.Request.URL.String() + //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL) + //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) + //if err != nil { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed") + //} + log.Println("RelayMidjourneyTaskImageSeed") + return nil +} + func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error @@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = params.TaskId midjRequest.Action = params.Action } else if relayMode == relayconstant.RelayModeMidjourneyModal { - if midjRequest.MaskBase64 == "" { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") - } + //if midjRequest.MaskBase64 == "" { + // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") + //} mjId = midjRequest.TaskId - midjRequest.Action = constant.MjActionInPaint + midjRequest.Action = constant.MjActionModal } originTask := model.GetByMJId(userId, mjId) @@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //} } - if midjRequest.Action == constant.MjActionInPaintPre { + if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom { consumeQuota = false } - // map model name - //modelMapping := c.GetString("model_mapping") - //isModelMapped := false - //if modelMapping != "" { - // modelMap := make(map[string]string) - // err := json.Unmarshal([]byte(modelMapping), &modelMap) - // if err != nil { - // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - // return &dto.MidjourneyResponse{ - // Code: 4, - // Description: "unmarshal_model_mapping_failed", - // } - // } - // if modelMap[imageModel] != "" { - // imageModel = modelMap[imageModel] - // isModelMapped = true - // } - //} - //baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - var requestBody io.Reader - //if isModelMapped { - // jsonStr, err := json.Marshal(midjRequest) - // if err != nil { - // return &dto.MidjourneyResponse{ - // Code: 4, - // Description: "marshal_text_request_failed", - // } - // } - // requestBody = bytes.NewBuffer(jsonStr) - //} else { - //} - requestBody = c.Request.Body - modelName := service.CoverActionToModelName(midjRequest.Action) modelPrice := common.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 @@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest) if err != nil { - return &dto.MidjourneyResponse{ - Code: 4, - Description: "create_request_failed", - } + return &midjResponseWithStatus.Response } - //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) - timeout := time.Second * 30 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) - // print request header - //log.Printf("request header: %s", req.Header) - //log.Printf("request body: %s", midjRequest.Prompt) - - defer cancel() - resp, err := service.GetHttpClient().Do(req) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed") - } - - err = req.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") - } - err = c.Request.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed") - } - var midjResponse dto.MidjourneyResponse + midjResponse := &midjResponseWithStatus.Response defer func(ctx context.Context) { if consumeQuota { @@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed") - } - err = resp.Body.Close() - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed") - } - if resp.StatusCode != 200 { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status") - } - err = json.Unmarshal(responseBody, &midjResponse) - log.Printf("responseBody: %s", string(responseBody)) - log.Printf("midjResponse: %v", midjResponse) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed") - } - // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md //1-提交成功 // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}} @@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } } //修改返回值 - if midjRequest.Action != constant.MjActionInPaintPre { + if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom { newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1) responseBody = []byte(newBody) } @@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody = []byte(newBody) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + c.Writer.WriteHeader(midjResponseWithStatus.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, bodyReader) if err != nil { return &dto.MidjourneyResponse{ Code: 4, Description: "copy_response_body_failed", } } - err = resp.Body.Close() + err = bodyReader.Close() if err != nil { return &dto.MidjourneyResponse{ Code: 4, diff --git a/router/relay-router.go b/router/relay-router.go index f572d8f..3c6910a 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -57,6 +57,7 @@ func SetRelayRouter(router *gin.Engine) { relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) relayMjRouter.POST("/notify", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) + relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) } //relayMjRouter.Use() diff --git a/service/error.go b/service/error.go index 91c78c8..424be5d 100644 --- a/service/error.go +++ b/service/error.go @@ -18,6 +18,13 @@ func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { } } +func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode { + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: *MidjourneyErrorWrapper(code, desc), + } +} + // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() diff --git a/service/midjourney.go b/service/midjourney.go index c04c4d3..17e54e1 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -1,11 +1,18 @@ package service import ( + "context" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "log" + "net/http" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" "strconv" "strings" + "time" ) func CoverActionToModelName(mjAction string) string { @@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin case relayconstant.RelayModeMidjourneyChange: action = midjRequest.Action case relayconstant.RelayModeMidjourneyModal: - action = constant.MjActionInPaint + action = constant.MjActionModal case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj } else if strings.Contains(action, "reroll") { midjRequest.Action = constant.MjActionReRoll midjRequest.Index = 1 - } else if action == "Outpaint" || action == "CustomZoom" { + } else if action == "Outpaint" { midjRequest.Action = constant.MjActionZoom midjRequest.Index = 1 + } else if action == "CustomZoom" { + midjRequest.Action = constant.MjActionCustomZoom + midjRequest.Index = 1 } else if action == "Inpaint" { - midjRequest.Action = constant.MjActionInPaintPre + midjRequest.Action = constant.MjActionInPaint midjRequest.Index = 1 } else { return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId) @@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest { changeParams.Index = index return changeParams } + +func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) { + var nullBytes []byte + var requestBody io.Reader + requestBody = c.Request.Body + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) + defer cancel() + resp, err := GetHttpClient().Do(req) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err + } + statusCode := resp.StatusCode + //if statusCode != 200 { + // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil + //} + err = req.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + err = c.Request.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err + } + var midjResponse dto.MidjourneyResponse + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err + } + err = resp.Body.Close() + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err + } + + err = json.Unmarshal(responseBody, &midjResponse) + log.Printf("responseBody: %s", string(responseBody)) + if err != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } + //log.Printf("midjResponse: %v", midjResponse) + //for k, v := range resp.Header { + // c.Writer.Header().Set(k, v[0]) + //} + return &dto.MidjourneyResponseWithStatusCode{ + StatusCode: statusCode, + Response: midjResponse, + }, responseBody, nil +} diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index 88da1cd..4843b2f 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -46,11 +46,13 @@ function renderType(type) { case 'REROLL': return 重绘; case 'INPAINT': - return 局部重绘; + return 局部重绘-提交; case 'ZOOM': return 变焦; - case 'INPAINT_PRE': - return 局部重绘-预处理; + case 'CUSTOM_ZOOM': + return 自定义变焦-提交; + case 'MODAL': + return 窗口处理; default: return 未知; } @@ -62,7 +64,7 @@ function renderCode(code) { case 1: return 已提交; case 21: - return 排队中; + return 等待中; case 22: return 重复提交; default: diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 757b56c..ccc18aa 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -109,8 +109,9 @@ const EditChannel = (props) => { 'mj_describe', 'mj_zoom', 'mj_shorten', - 'mj_inpaint_pre', + 'mj_modal', 'mj_inpaint', + 'mj_custom_zoom', 'mj_high_variation', 'mj_low_variation', 'mj_pan',