diff --git a/common/constants.go b/common/constants.go
index 98fa67a..cbb7861 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -189,7 +189,7 @@ const (
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
- ChannelTypeOpenAISB = 5
+ ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
diff --git a/constant/midjourney.go b/constant/midjourney.go
new file mode 100644
index 0000000..dbcc5c8
--- /dev/null
+++ b/constant/midjourney.go
@@ -0,0 +1,16 @@
+package constant
+
+const (
+ MjErrorUnknown = 5
+ MjRequestError = 4
+)
+
+const (
+ MjActionImagine = "IMAGINE"
+ MjActionDescribe = "DESCRIBE"
+ MjActionBlend = "BLEND"
+ MjActionUpscale = "UPSCALE"
+ MjActionVariation = "VARIATION"
+ MjActionInPaint = "INPAINT"
+ MjActionInPaintPre = "INPAINT_PRE"
+)
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 4bcd4d4..96f82ee 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
- case common.ChannelTypeOpenAISB:
- return updateChannelOpenAISBBalance(channel)
+ //case common.ChannelTypeOpenAISB:
+ // return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 1a42270..cac253c 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -10,8 +10,8 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/dto"
"one-api/model"
- relay2 "one-api/relay"
"one-api/service"
"strconv"
"strings"
@@ -75,11 +75,11 @@ import (
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
- var responseItem Midjourney
+ var responseItem MidjourneyDto
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
- if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
+ if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
@@ -228,12 +228,16 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
+ if resp.StatusCode != http.StatusOK {
+ common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ continue
+ }
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
- var responseItems []relay2.Midjourney
+ var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -259,6 +263,10 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
+ if responseItem.Buttons != nil {
+ buttonStr, _ := json.Marshal(responseItem.Buttons)
+ task.Buttons = string(buttonStr)
+ }
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -286,7 +294,7 @@ func UpdateMidjourneyTaskBulk() {
}
}
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}
diff --git a/controller/relay.go b/controller/relay.go
index 911a7c5..a42db2e 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -62,7 +62,13 @@ func Relay(c *gin.Context) {
func RelayMidjourney(c *gin.Context) {
relayMode := relayconstant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
+ if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyAction
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyModal
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = relayconstant.RelayModeMidjourneyBlend
@@ -86,35 +92,24 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
+ //case relayconstant.RelayModeMidjourneyModal:
+ // err = relay.RelayMidjournneyModal(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = common.RetryTimes
- }
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.Code == 30 {
- err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
- }
- c.JSON(429, gin.H{
- "error": fmt.Sprintf("%s %s", err.Description, err.Result),
- "type": "upstream_error",
- })
+ if err.Code == 30 {
+ err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
+ c.JSON(429, gin.H{
+ "error": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "type": "upstream_error",
+ "code": err.Code,
+ })
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
- //if shouldDisableChannel(&err.Error) {
- // channelId := c.GetInt("channel_id")
- // channelName := c.GetString("channel_name")
- // disableChannel(channelId, channelName, err.Result)
- //};''''''''''''''''''''''''''''''''
}
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 4c67909..a16a65e 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -2,6 +2,8 @@ package dto
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
@@ -9,6 +11,7 @@ type MidjourneyRequest struct {
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
+ MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
@@ -17,3 +20,52 @@ type MidjourneyResponse struct {
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
+
+type MidjourneyDto struct {
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+}
+
+type MidjourneyStatus struct {
+ Status int `json:"status"`
+}
+type MidjourneyWithoutStatus struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ ImageUrl string `json:"image_url"`
+ Progress string `json:"progress"`
+ FailReason string `json:"fail_reason"`
+ ChannelId int `json:"channel_id"`
+}
+
+type ActionButton struct {
+ CustomId any `json:"customId"`
+ Emoji any `json:"emoji"`
+ Label any `json:"label"`
+ Type any `json:"type"`
+ Style any `json:"style"`
+}
diff --git a/middleware/auth.go b/middleware/auth.go
index ef774f6..a8dac30 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -125,12 +125,6 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
- requestURL := c.Request.URL.String()
- consumeQuota := true
- if strings.HasPrefix(requestURL, "/v1/models") {
- consumeQuota = false
- }
- c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
diff --git a/model/midjourney.go b/model/midjourney.go
index 0ef2e55..f20ab32 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -19,6 +19,7 @@ type Midjourney struct {
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
+ Buttons string `json:"buttons"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index beea7dc..c49caae 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -21,6 +21,8 @@ const (
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
+ RelayModeMidjourneyAction
+ RelayModeMidjourneyModal
)
func Path2RelayMode(path string) int {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index b2b9926..f667cd1 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -9,6 +9,7 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
@@ -20,51 +21,15 @@ import (
"github.com/gin-gonic/gin"
)
-type Midjourney struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
-}
-
-type MidjourneyStatus struct {
- Status int `json:"status"`
-}
-type MidjourneyWithoutStatus struct {
- Id int `json:"id"`
- Code int `json:"code"`
- UserId int `json:"user_id" gorm:"index"`
- Action string `json:"action"`
- MjId string `json:"mj_id" gorm:"index"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"prompt_en"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- ImageUrl string `json:"image_url"`
- Progress string `json:"progress"`
- FailReason string `json:"fail_reason"`
- ChannelId int `json:"channel_id"`
-}
-
var DefaultModelPrice = map[string]float64{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
@@ -108,7 +73,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
- var midjRequest Midjourney
+ var midjRequest dto.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
@@ -147,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
-func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
+func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -167,9 +132,41 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
+ if originTask.Buttons != "" {
+ var buttons []dto.ActionButton
+ err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
+ if err == nil {
+ midjourneyTask.Buttons = buttons
+ }
+ }
return
}
+func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse {
+ userId := c.GetInt("id")
+ var midjRequest dto.MidjourneyRequest
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ originTask := model.GetByMJId(userId, midjRequest.TaskId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ }
+
+ respBody, err := json.Marshal(midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+
+}
+
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -184,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := getMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -203,16 +200,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "do_request_failed",
}
}
- var tasks []Midjourney
+ var tasks []dto.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := getMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
- tasks = make([]Midjourney, 0)
+ tasks = make([]dto.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
@@ -235,44 +232,32 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
-const (
- // type 1 根据 mode 价格不同
- MJSubmitActionImagine = "IMAGINE"
- MJSubmitActionVariation = "VARIATION" //变换
- MJSubmitActionBlend = "BLEND" //混图
-
- MJSubmitActionReroll = "REROLL" //重新生成
- // type 2 固定价格
- MJSubmitActionDescribe = "DESCRIBE"
- MJSubmitActionUpscale = "UPSCALE" // 放大
-)
-
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
+ consumeQuota := true
var midjRequest dto.MidjourneyRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "bind_request_body_failed",
- }
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+
+ if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
+ mjErr := coverPlusActionToNormalAction(&midjRequest)
+ if mjErr != nil {
+ return mjErr
}
+ relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "prompt_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
midjRequest.Action = "IMAGINE"
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
@@ -283,71 +268,58 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "taskId_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "action_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "index_can_only_be_1_2_3_4",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
params := convertSimpleChangeParams(midjRequest.Content)
if params == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_parse_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
}
mjId = params.ID
midjRequest.Action = params.Action
+ } else if relayMode == relayconstant.RelayModeMidjourneyModal {
+ if midjRequest.MaskBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ }
+ mjId = midjRequest.TaskId
+ midjRequest.Action = "INPAINT"
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_no_found",
- }
- } else if originTask.Action == "UPSCALE" {
- //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "upscale_task_can_not_be_change",
- }
- } else if originTask.Status != "SUCCESS" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_status_is_not_success",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
+ } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "channel_not_found",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
- log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
+ log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
+
+ if channelType == common.ChannelTypeMidjourneyPlus {
+ // plus
+ } else {
+ // 普通版渠道
+
+ }
+ }
+
+ if midjRequest.Action == constant.MjActionInPaintPre {
+ consumeQuota = false
}
// map model name
@@ -379,7 +351,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
@@ -394,6 +365,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else {
requestBody = c.Request.Body
}
+
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
modelPrice := common.GetModelPrice(mjAction, true)
// 如果没有配置价格,则使用默认价格
@@ -489,9 +461,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
- //if consumeQuota {
- //
- //}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -651,3 +620,43 @@ func convertSimpleChangeParams(content string) *taskChangeParams {
changeParams.Index = index
return changeParams
}
+
+func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+ // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+ customId := midjRequest.CustomId
+ if customId == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+ }
+ splits := strings.Split(customId, "::")
+ var action string
+ if splits[1] == "JOB" {
+ action = splits[2]
+ } else {
+ action = splits[1]
+ }
+
+ if action == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ if strings.Contains(action, "upsample") {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionUpscale
+ } else if strings.Contains(action, "variation") {
+ midjRequest.Action = constant.MjActionVariation
+ } else if strings.Contains(action, "pan") {
+ midjRequest.Action = constant.MjActionVariation
+ midjRequest.Index = 1
+ } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") {
+ midjRequest.Action = constant.MjActionInPaintPre
+ } else if action == "Inpaint" {
+ midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Index = 1
+ } else {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 6a30a5a..68b762b 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -47,6 +47,8 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
+ relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
diff --git a/service/error.go b/service/error.go
index 303bcf7..91c78c8 100644
--- a/service/error.go
+++ b/service/error.go
@@ -11,6 +11,13 @@ import (
"strings"
)
+func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
+ return &dto.MidjourneyResponse{
+ Code: code,
+ Description: desc,
+ }
+}
+
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 1f71208..4f17c14 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,10 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'INPAINT':
+ return 局部重绘;
+ case 'INPAINT_PRE':
+ return 局部重绘-预处理;
default:
return 未知;
}
@@ -68,6 +72,8 @@ function renderStatus(type) {
return 执行中;
case 'FAILURE':
return 失败;
+ case 'MODAL':
+ return 窗口等待;
default:
return 未知;
}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index a641a02..bb18d0d 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -1,6 +1,7 @@
export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
+ {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'},
{key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},