diff --git a/common/model-ratio.go b/common/model-ratio.go
index 153b748..3231d95 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -100,13 +100,14 @@ var DefaultModelPrice = map[string]float64{
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
- "mj_inpaint": 0.1,
+ "mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
- "mj_inpaint_pre": 0,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
diff --git a/constant/midjourney.go b/constant/midjourney.go
index 92e2f23..f4ae2e4 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -13,8 +13,9 @@ const (
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
- MjActionInPaintPre = "INPAINT_PRE"
+ MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
+ MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
@@ -29,9 +30,10 @@ var MidjourneyModel2Action = map[string]string{
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
+ "mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
- "mj_inpaint_pre": MjActionInPaintPre,
"mj_zoom": MjActionZoom,
+ "mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
diff --git a/controller/relay.go b/controller/relay.go
index fa5493a..e31679d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -67,6 +67,8 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
+ case relayconstant.RelayModeMidjourneyTaskImageSeed:
+ err = relay.RelayMidjourneyTaskImageSeed(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index f81756c..d3c3583 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -28,6 +28,11 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
+type MidjourneyResponseWithStatusCode struct {
+ StatusCode int `json:"statusCode"`
+ Response MidjourneyResponse
+}
+
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
diff --git a/middleware/distributor.go b/middleware/distributor.go
index c1b3ccb..ed457a3 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -48,7 +48,8 @@ func Distribute() func(c *gin.Context) {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
- relayMode == relayconstant.RelayModeMidjourneyNotify {
+ relayMode == relayconstant.RelayModeMidjourneyNotify ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 9f13726..197efdc 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -17,6 +17,7 @@ const (
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
+ RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
@@ -77,6 +78,8 @@ func Path2RelayModeMidjourney(path string) int {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
+ } else if strings.HasSuffix(path, "/image-seed") {
+ relayMode = RelayModeMidjourneyTaskImageSeed
} else if strings.HasSuffix(path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 7342717..8eebaeb 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+ //taskId := c.Param("id")
+ //userId := c.GetInt("id")
+ //originTask := model.GetByMJId(userId, taskId)
+ //if originTask == nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ //}
+ //channel, err := model.GetChannelById(originTask.ChannelId, false)
+ //if err != nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ //}
+ //if channel.Status != common.ChannelStatusEnabled {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ //}
+ //c.Set("channel_id", originTask.ChannelId)
+ //requestURL := c.Request.URL.String()
+ //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+ //if err != nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
+ //}
+ log.Println("RelayMidjourneyTaskImageSeed")
+ return nil
+}
+
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
- if midjRequest.MaskBase64 == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
- }
+ //if midjRequest.MaskBase64 == "" {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ //}
mjId = midjRequest.TaskId
- midjRequest.Action = constant.MjActionInPaint
+ midjRequest.Action = constant.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
@@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//}
}
- if midjRequest.Action == constant.MjActionInPaintPre {
+ if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
consumeQuota = false
}
- // map model name
- //modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- //if modelMapping != "" {
- // modelMap := make(map[string]string)
- // err := json.Unmarshal([]byte(modelMapping), &modelMap)
- // if err != nil {
- // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- // return &dto.MidjourneyResponse{
- // Code: 4,
- // Description: "unmarshal_model_mapping_failed",
- // }
- // }
- // if modelMap[imageModel] != "" {
- // imageModel = modelMap[imageModel]
- // isModelMapped = true
- // }
- //}
-
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- var requestBody io.Reader
- //if isModelMapped {
- // jsonStr, err := json.Marshal(midjRequest)
- // if err != nil {
- // return &dto.MidjourneyResponse{
- // Code: 4,
- // Description: "marshal_text_request_failed",
- // }
- // }
- // requestBody = bytes.NewBuffer(jsonStr)
- //} else {
- //}
- requestBody = c.Request.Body
-
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
@@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "create_request_failed",
- }
+ return &midjResponseWithStatus.Response
}
- //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
- timeout := time.Second * 30
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
- // print request header
- //log.Printf("request header: %s", req.Header)
- //log.Printf("request body: %s", midjRequest.Prompt)
-
- defer cancel()
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
- }
-
- err = req.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
- }
- err = c.Request.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
- }
- var midjResponse dto.MidjourneyResponse
+ midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
if consumeQuota {
@@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
- }
- err = resp.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
- }
- if resp.StatusCode != 200 {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
- }
- err = json.Unmarshal(responseBody, &midjResponse)
- log.Printf("responseBody: %s", string(responseBody))
- log.Printf("midjResponse: %v", midjResponse)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
- }
-
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//修改返回值
- if midjRequest.Action != constant.MjActionInPaintPre {
+ if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
@@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
+ _, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
- err = resp.Body.Close()
+ err = bodyReader.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
diff --git a/router/relay-router.go b/router/relay-router.go
index f572d8f..3c6910a 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -57,6 +57,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
}
//relayMjRouter.Use()
diff --git a/service/error.go b/service/error.go
index 91c78c8..424be5d 100644
--- a/service/error.go
+++ b/service/error.go
@@ -18,6 +18,13 @@ func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
}
}
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: *MidjourneyErrorWrapper(code, desc),
+ }
+}
+
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
diff --git a/service/midjourney.go b/service/midjourney.go
index c04c4d3..17e54e1 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -1,11 +1,18 @@
package service
import (
+ "context"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "log"
+ "net/http"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"strconv"
"strings"
+ "time"
)
func CoverActionToModelName(mjAction string) string {
@@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
case relayconstant.RelayModeMidjourneyChange:
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
- action = constant.MjActionInPaint
+ action = constant.MjActionModal
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "reroll") {
midjRequest.Action = constant.MjActionReRoll
midjRequest.Index = 1
- } else if action == "Outpaint" || action == "CustomZoom" {
+ } else if action == "Outpaint" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
+ } else if action == "CustomZoom" {
+ midjRequest.Action = constant.MjActionCustomZoom
+ midjRequest.Index = 1
} else if action == "Inpaint" {
- midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Action = constant.MjActionInPaint
midjRequest.Index = 1
} else {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
@@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
changeParams.Index = index
return changeParams
}
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+ var nullBytes []byte
+ var requestBody io.Reader
+ requestBody = c.Request.Body
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // 使用带有超时的 context 创建新的请求
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+ defer cancel()
+ resp, err := GetHttpClient().Do(req)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ statusCode := resp.StatusCode
+ //if statusCode != 200 {
+ // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+ //}
+ err = req.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ var midjResponse dto.MidjourneyResponse
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
+ }
+
+ err = json.Unmarshal(responseBody, &midjResponse)
+ log.Printf("responseBody: %s", string(responseBody))
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+ }
+ //log.Printf("midjResponse: %v", midjResponse)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: midjResponse,
+ }, responseBody, nil
+}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 88da1cd..4843b2f 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -46,11 +46,13 @@ function renderType(type) {
case 'REROLL':
return 重绘;
case 'INPAINT':
- return 局部重绘;
+ return 局部重绘-提交;
case 'ZOOM':
return 变焦;
- case 'INPAINT_PRE':
- return 局部重绘-预处理;
+ case 'CUSTOM_ZOOM':
+ return 自定义变焦-提交;
+ case 'MODAL':
+ return 窗口处理;
default:
return 未知;
}
@@ -62,7 +64,7 @@ function renderCode(code) {
case 1:
return 已提交;
case 21:
- return 排队中;
+ return 等待中;
case 22:
return 重复提交;
default:
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 757b56c..ccc18aa 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -109,8 +109,9 @@ const EditChannel = (props) => {
'mj_describe',
'mj_zoom',
'mj_shorten',
- 'mj_inpaint_pre',
+ 'mj_modal',
'mj_inpaint',
+ 'mj_custom_zoom',
'mj_high_variation',
'mj_low_variation',
'mj_pan',