From 37c0c8ebdd252e5c9dcd205200e416e758dd8dc2 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 15:37:01 +0800
Subject: [PATCH 01/16] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC?=
=?UTF-8?q?=E5=AE=B9midjourney-proxy-plus?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/constants.go | 2 +-
constant/midjourney.go | 16 ++
controller/channel-billing.go | 4 +-
controller/midjourney.go | 18 +-
controller/relay.go | 37 ++--
dto/midjourney.go | 52 ++++++
middleware/auth.go | 6 -
model/midjourney.go | 1 +
relay/constant/relay_mode.go | 2 +
relay/relay-mj.go | 245 +++++++++++++------------
router/relay-router.go | 2 +
service/error.go | 7 +
web/src/components/MjLogsTable.js | 6 +
web/src/constants/channel.constants.js | 1 +
14 files changed, 246 insertions(+), 153 deletions(-)
create mode 100644 constant/midjourney.go
diff --git a/common/constants.go b/common/constants.go
index 98fa67a..cbb7861 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -189,7 +189,7 @@ const (
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
- ChannelTypeOpenAISB = 5
+ ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
diff --git a/constant/midjourney.go b/constant/midjourney.go
new file mode 100644
index 0000000..dbcc5c8
--- /dev/null
+++ b/constant/midjourney.go
@@ -0,0 +1,16 @@
+package constant
+
+const (
+ MjErrorUnknown = 5
+ MjRequestError = 4
+)
+
+const (
+ MjActionImagine = "IMAGINE"
+ MjActionDescribe = "DESCRIBE"
+ MjActionBlend = "BLEND"
+ MjActionUpscale = "UPSCALE"
+ MjActionVariation = "VARIATION"
+ MjActionInPaint = "INPAINT"
+ MjActionInPaintPre = "INPAINT_PRE"
+)
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 4bcd4d4..96f82ee 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
- case common.ChannelTypeOpenAISB:
- return updateChannelOpenAISBBalance(channel)
+ //case common.ChannelTypeOpenAISB:
+ // return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 1a42270..cac253c 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -10,8 +10,8 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/dto"
"one-api/model"
- relay2 "one-api/relay"
"one-api/service"
"strconv"
"strings"
@@ -75,11 +75,11 @@ import (
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
- var responseItem Midjourney
+ var responseItem MidjourneyDto
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
- if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
+ if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
@@ -228,12 +228,16 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
+ if resp.StatusCode != http.StatusOK {
+ common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ continue
+ }
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
- var responseItems []relay2.Midjourney
+ var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -259,6 +263,10 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
+ if responseItem.Buttons != nil {
+ buttonStr, _ := json.Marshal(responseItem.Buttons)
+ task.Buttons = string(buttonStr)
+ }
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -286,7 +294,7 @@ func UpdateMidjourneyTaskBulk() {
}
}
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}
diff --git a/controller/relay.go b/controller/relay.go
index 911a7c5..a42db2e 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -62,7 +62,13 @@ func Relay(c *gin.Context) {
func RelayMidjourney(c *gin.Context) {
relayMode := relayconstant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
+ if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyAction
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyModal
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = relayconstant.RelayModeMidjourneyBlend
@@ -86,35 +92,24 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
+ //case relayconstant.RelayModeMidjourneyModal:
+ // err = relay.RelayMidjournneyModal(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = common.RetryTimes
- }
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.Code == 30 {
- err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
- }
- c.JSON(429, gin.H{
- "error": fmt.Sprintf("%s %s", err.Description, err.Result),
- "type": "upstream_error",
- })
+ if err.Code == 30 {
+ err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
+ c.JSON(429, gin.H{
+ "error": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "type": "upstream_error",
+ "code": err.Code,
+ })
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
- //if shouldDisableChannel(&err.Error) {
- // channelId := c.GetInt("channel_id")
- // channelName := c.GetString("channel_name")
- // disableChannel(channelId, channelName, err.Result)
- //};''''''''''''''''''''''''''''''''
}
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 4c67909..a16a65e 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -2,6 +2,8 @@ package dto
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
@@ -9,6 +11,7 @@ type MidjourneyRequest struct {
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
+ MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
@@ -17,3 +20,52 @@ type MidjourneyResponse struct {
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
+
+type MidjourneyDto struct {
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+}
+
+type MidjourneyStatus struct {
+ Status int `json:"status"`
+}
+type MidjourneyWithoutStatus struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ ImageUrl string `json:"image_url"`
+ Progress string `json:"progress"`
+ FailReason string `json:"fail_reason"`
+ ChannelId int `json:"channel_id"`
+}
+
+type ActionButton struct {
+ CustomId any `json:"customId"`
+ Emoji any `json:"emoji"`
+ Label any `json:"label"`
+ Type any `json:"type"`
+ Style any `json:"style"`
+}
diff --git a/middleware/auth.go b/middleware/auth.go
index ef774f6..a8dac30 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -125,12 +125,6 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
- requestURL := c.Request.URL.String()
- consumeQuota := true
- if strings.HasPrefix(requestURL, "/v1/models") {
- consumeQuota = false
- }
- c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
diff --git a/model/midjourney.go b/model/midjourney.go
index 0ef2e55..f20ab32 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -19,6 +19,7 @@ type Midjourney struct {
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
+ Buttons string `json:"buttons"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index beea7dc..c49caae 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -21,6 +21,8 @@ const (
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
+ RelayModeMidjourneyAction
+ RelayModeMidjourneyModal
)
func Path2RelayMode(path string) int {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index b2b9926..f667cd1 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -9,6 +9,7 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
@@ -20,51 +21,15 @@ import (
"github.com/gin-gonic/gin"
)
-type Midjourney struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
-}
-
-type MidjourneyStatus struct {
- Status int `json:"status"`
-}
-type MidjourneyWithoutStatus struct {
- Id int `json:"id"`
- Code int `json:"code"`
- UserId int `json:"user_id" gorm:"index"`
- Action string `json:"action"`
- MjId string `json:"mj_id" gorm:"index"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"prompt_en"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- ImageUrl string `json:"image_url"`
- Progress string `json:"progress"`
- FailReason string `json:"fail_reason"`
- ChannelId int `json:"channel_id"`
-}
-
var DefaultModelPrice = map[string]float64{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
@@ -108,7 +73,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
- var midjRequest Midjourney
+ var midjRequest dto.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
@@ -147,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
-func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
+func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -167,9 +132,41 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
+ if originTask.Buttons != "" {
+ var buttons []dto.ActionButton
+ err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
+ if err == nil {
+ midjourneyTask.Buttons = buttons
+ }
+ }
return
}
+func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse {
+ userId := c.GetInt("id")
+ var midjRequest dto.MidjourneyRequest
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ originTask := model.GetByMJId(userId, midjRequest.TaskId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ }
+
+ respBody, err := json.Marshal(midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+
+}
+
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -184,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := getMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -203,16 +200,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "do_request_failed",
}
}
- var tasks []Midjourney
+ var tasks []dto.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
- midjourneyTask := getMidjourneyTaskModel(c, originTask)
+ midjourneyTask := getMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
- tasks = make([]Midjourney, 0)
+ tasks = make([]dto.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
@@ -235,44 +232,32 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
-const (
- // type 1 根据 mode 价格不同
- MJSubmitActionImagine = "IMAGINE"
- MJSubmitActionVariation = "VARIATION" //变换
- MJSubmitActionBlend = "BLEND" //混图
-
- MJSubmitActionReroll = "REROLL" //重新生成
- // type 2 固定价格
- MJSubmitActionDescribe = "DESCRIBE"
- MJSubmitActionUpscale = "UPSCALE" // 放大
-)
-
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
+ consumeQuota := true
var midjRequest dto.MidjourneyRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "bind_request_body_failed",
- }
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+
+ if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
+ mjErr := coverPlusActionToNormalAction(&midjRequest)
+ if mjErr != nil {
+ return mjErr
}
+ relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "prompt_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
midjRequest.Action = "IMAGINE"
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
@@ -283,71 +268,58 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "taskId_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "action_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "index_can_only_be_1_2_3_4",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_is_required",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
params := convertSimpleChangeParams(midjRequest.Content)
if params == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "content_parse_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
}
mjId = params.ID
midjRequest.Action = params.Action
+ } else if relayMode == relayconstant.RelayModeMidjourneyModal {
+ if midjRequest.MaskBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ }
+ mjId = midjRequest.TaskId
+ midjRequest.Action = "INPAINT"
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_no_found",
- }
- } else if originTask.Action == "UPSCALE" {
- //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "upscale_task_can_not_be_change",
- }
- } else if originTask.Status != "SUCCESS" {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_status_is_not_success",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
+ } else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "channel_not_found",
- }
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
- log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
+ log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
+
+ if channelType == common.ChannelTypeMidjourneyPlus {
+ // plus
+ } else {
+ // 普通版渠道
+
+ }
+ }
+
+ if midjRequest.Action == constant.MjActionInPaintPre {
+ consumeQuota = false
}
// map model name
@@ -379,7 +351,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
@@ -394,6 +365,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else {
requestBody = c.Request.Body
}
+
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
modelPrice := common.GetModelPrice(mjAction, true)
// 如果没有配置价格,则使用默认价格
@@ -489,9 +461,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
- //if consumeQuota {
- //
- //}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -651,3 +620,43 @@ func convertSimpleChangeParams(content string) *taskChangeParams {
changeParams.Index = index
return changeParams
}
+
+func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+ // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+ customId := midjRequest.CustomId
+ if customId == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+ }
+ splits := strings.Split(customId, "::")
+ var action string
+ if splits[1] == "JOB" {
+ action = splits[2]
+ } else {
+ action = splits[1]
+ }
+
+ if action == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ if strings.Contains(action, "upsample") {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionUpscale
+ } else if strings.Contains(action, "variation") {
+ midjRequest.Action = constant.MjActionVariation
+ } else if strings.Contains(action, "pan") {
+ midjRequest.Action = constant.MjActionVariation
+ midjRequest.Index = 1
+ } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") {
+ midjRequest.Action = constant.MjActionInPaintPre
+ } else if action == "Inpaint" {
+ midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Index = 1
+ } else {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 6a30a5a..68b762b 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -47,6 +47,8 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
+ relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
diff --git a/service/error.go b/service/error.go
index 303bcf7..91c78c8 100644
--- a/service/error.go
+++ b/service/error.go
@@ -11,6 +11,13 @@ import (
"strings"
)
+func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
+ return &dto.MidjourneyResponse{
+ Code: code,
+ Description: desc,
+ }
+}
+
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 1f71208..4f17c14 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,10 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'INPAINT':
+ return 局部重绘;
+ case 'INPAINT_PRE':
+ return 局部重绘-预处理;
default:
return 未知;
}
@@ -68,6 +72,8 @@ function renderStatus(type) {
return 执行中;
case 'FAILURE':
return 失败;
+ case 'MODAL':
+ return 窗口等待;
default:
return 未知;
}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index a641a02..bb18d0d 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -1,6 +1,7 @@
export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
+ {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'},
{key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},
From fd3a41bacb326f7d1b15b9b447efafced679d419 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 16:19:22 +0800
Subject: [PATCH 02/16] =?UTF-8?q?feat:=20=E8=AF=B7=E6=B1=82=E8=B6=85?=
=?UTF-8?q?=E6=97=B6=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
dto/midjourney.go | 1 +
relay/relay-mj.go | 27 +++++++++++++--------------
web/src/components/MjLogsTable.js | 2 ++
web/src/pages/Channel/EditChannel.js | 6 ++++++
4 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/dto/midjourney.go b/dto/midjourney.go
index a16a65e..4fef4e1 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -25,6 +25,7 @@ type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
+ BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index f667cd1..5fafc89 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -112,7 +112,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
-func getMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -181,7 +181,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
- midjourneyTask := getMidjourneyTaskDto(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -204,7 +204,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
- midjourneyTask := getMidjourneyTaskDto(c, originTask)
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
@@ -403,23 +403,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-
+ timeout := time.Second * 30
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // 使用带有超时的 context 创建新的请求
+ req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- //mjToken := ""
- //if c.Request.Header.Get("ApiKey") != "" {
- // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
- //}
- //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
- log.Printf("request header: %s", req.Header)
- log.Printf("request body: %s", midjRequest.Prompt)
+ //log.Printf("request header: %s", req.Header)
+ //log.Printf("request body: %s", midjRequest.Prompt)
+ defer cancel()
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "do_request_failed",
}
}
@@ -427,14 +426,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
err = req.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
- Code: 4,
+ Code: 5,
Description: "close_request_body_failed",
}
}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 4f17c14..4accf54 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,8 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'REROLL':
+ return 重绘;
case 'INPAINT':
return 局部重绘;
case 'INPAINT_PRE':
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 7221b9a..ee79368 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -95,6 +95,12 @@ const EditChannel = (props) => {
case 26:
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
break;
+ case 2:
+ localModels = ['midjourney'];
+ break;
+ case 5:
+ localModels = ['midjourney'];
+ break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
}
From 728dbed28d5738d1d3b4a4925f719cb05026af1f Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 16:29:27 +0800
Subject: [PATCH 03/16] =?UTF-8?q?feat:=20=E5=85=BC=E5=AE=B9=E5=8F=98?=
=?UTF-8?q?=E7=84=A6=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
constant/midjourney.go | 1 +
relay/relay-mj.go | 5 ++++-
web/src/components/MjLogsTable.js | 2 ++
3 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/constant/midjourney.go b/constant/midjourney.go
index dbcc5c8..c184435 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -13,4 +13,5 @@ const (
MjActionVariation = "VARIATION"
MjActionInPaint = "INPAINT"
MjActionInPaintPre = "INPAINT_PRE"
+ MjActionZoom = "ZOOM"
)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 5fafc89..d582055 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -27,6 +27,7 @@ var DefaultModelPrice = map[string]float64{
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
"mj_inpaint_pre": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
@@ -646,11 +647,13 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
midjRequest.Action = constant.MjActionUpscale
} else if strings.Contains(action, "variation") {
midjRequest.Action = constant.MjActionVariation
+ midjRequest.Index = 1
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionVariation
midjRequest.Index = 1
} else if action == "Outpaint" || strings.Contains(action, "CustomZoom") {
- midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Action = constant.MjActionZoom
+ midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = constant.MjActionInPaintPre
midjRequest.Index = 1
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 4accf54..a1ffeb6 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -39,6 +39,8 @@ function renderType(type) {
return 重绘;
case 'INPAINT':
return 局部重绘;
+ case 'ZOOM':
+ return 变焦;
case 'INPAINT_PRE':
return 局部重绘-预处理;
default:
From 2ad591411eff3f7f1ef91cc012e25ee915dea550 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 17:46:34 +0800
Subject: [PATCH 04/16] feat: support shorten
---
Midjourney.md | 287 ++----------------------------
constant/midjourney.go | 1 +
controller/midjourney.go | 4 +
controller/relay.go | 3 +
dto/midjourney.go | 40 +++--
model/midjourney.go | 1 +
relay/constant/relay_mode.go | 1 +
relay/relay-mj.go | 58 +++---
router/relay-router.go | 1 +
web/src/components/MjLogsTable.js | 2 +
10 files changed, 74 insertions(+), 324 deletions(-)
diff --git a/Midjourney.md b/Midjourney.md
index fe4d433..becc9c9 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -7,285 +7,28 @@
```json
{
"gpt-4-gizmo-*": 0.1,
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05
}
```
## 渠道设置
-### 对接 midjourney-proxy
+### 对接 midjourney-proxy(plus)
1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
-1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
-2. 地址填写上游new api的地址,例如:http://localhost:8080
-3. 密钥填写上游new api的密钥
-
-## 任务提交
-
-### 绘图变化
-
-**接口地址**:`/mj/submit/change`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "action"
-:
- "UPSCALE",
- "index"
-:
- 1,
- "notifyHook"
-:
- "",
- "state"
-:
- "",
- "taskId"
-:
- "1320098173412546"
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
-| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
-| action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
-| index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| state | 自定义参数 | | false | string | |
-| taskId | 任务ID | | true | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-### 提交Imagine任务
-
-**接口地址**:`/mj/submit/imagine`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
- "base64"
-:
- "",
- "notifyHook"
-:
- "",
- "prompt"
-:
- "Cat",
- "state"
-:
- ""
-}
-```
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------------------------|-------------------------|------|-------|-------------|-------------|
-| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
-| base64 | 垫图base64 | | false | string | |
-| notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
-| prompt | 提示词 | | true | string | |
-| state | 自定义参数 | | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 提交结果 |
-| 201 | Created | |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|-------------------------------------------|----------------|----------------|
-| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述 | string | |
-| properties | 扩展字段 | object | |
-| result | 任务ID | string | |
-
-**响应示例**:
-
-```javascript
-{
- "code"
-:
- 1,
- "description"
-:
- "提交成功",
- "properties"
-:
- {
- }
-,
- "result"
-:
- 1320098173412546
-}
-```
-
-## 任务查询
-
-### 指定ID获取任务
-
-**接口地址**:`/mj/task/{id}/fetch`
-
-**请求方式**:`GET`
-
-**请求数据类型**:`application/x-www-form-urlencoded`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
-|------|------|------|-------|--------|--------|
-| id | 任务ID | path | false | string | |
-
-**响应状态**:
-
-| 状态码 | 说明 | schema |
-|-----|--------------|--------|
-| 200 | OK | 任务 |
-| 401 | Unauthorized | |
-| 403 | Forbidden | |
-| 404 | Not Found | |
-
-**响应参数**:
-
-| 参数名称 | 参数说明 | 类型 | schema |
-|-------------|----------------------------------------------------------|----------------|----------------|
-| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
-| description | 任务描述 | string | |
-| failReason | 失败原因 | string | |
-| finishTime | 结束时间 | integer(int64) | integer(int64) |
-| id | 任务ID | string | |
-| imageUrl | 图片url | string | |
-| progress | 任务进度 | string | |
-| prompt | 提示词 | string | |
-| promptEn | 提示词-英文 | string | |
-| startTime | 开始执行时间 | integer(int64) | integer(int64) |
-| state | 自定义参数 | string | |
-| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
-| submitTime | 提交时间 | integer(int64) | integer(int64) |
-
-**响应示例**:
-
-```javascript
-{
- "action"
-:
- "",
- "description"
-:
- "",
- "failReason"
-:
- "",
- "finishTime"
-:
- 0,
- "id"
-:
- "",
- "imageUrl"
-:
- "",
- "progress"
-:
- "",
- "prompt"
-:
- "",
- "promptEn"
-:
- "",
- "startTime"
-:
- 0,
- "state"
-:
- "",
- "status"
-:
- "",
- "submitTime"
-:
- 0
-}
-```
\ No newline at end of file
+1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
+2. 地址填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/constant/midjourney.go b/constant/midjourney.go
index c184435..a5bccb7 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -14,4 +14,5 @@ const (
MjActionInPaint = "INPAINT"
MjActionInPaintPre = "INPAINT_PRE"
MjActionZoom = "ZOOM"
+ MjActionShorten = "SHORTEN"
)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index cac253c..b666e91 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -263,6 +263,10 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
+ if responseItem.Properties != nil {
+ propertiesStr, _ := json.Marshal(responseItem.Properties)
+ task.Properties = string(propertiesStr)
+ }
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
diff --git a/controller/relay.go b/controller/relay.go
index a42db2e..7652840 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -68,6 +68,9 @@ func RelayMidjourney(c *gin.Context) {
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
// midjourney plus
relayMode = relayconstant.RelayModeMidjourneyModal
+ } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") {
+ // midjourney plus
+ relayMode = relayconstant.RelayModeMidjourneyShorten
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 4fef4e1..d3b19d5 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -22,23 +22,24 @@ type MidjourneyResponse struct {
}
type MidjourneyDto struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- CustomId string `json:"customId"`
- BotType string `json:"botType"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
- Buttons any `json:"buttons"`
- MaskBase64 string `json:"maskBase64"`
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+ Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
@@ -70,3 +71,8 @@ type ActionButton struct {
Type any `json:"type"`
Style any `json:"style"`
}
+
+type Properties struct {
+ FinalPrompt string `json:"finalPrompt"`
+ FinalZhPrompt string `json:"finalZhPrompt"`
+}
diff --git a/model/midjourney.go b/model/midjourney.go
index f20ab32..dd065a3 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -20,6 +20,7 @@ type Midjourney struct {
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
+ Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index c49caae..d8dc7ee 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -23,6 +23,7 @@ const (
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
+ RelayModeMidjourneyShorten
)
func Path2RelayMode(path string) int {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index d582055..a1f6ed4 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -31,6 +31,7 @@ var DefaultModelPrice = map[string]float64{
"mj_inpaint_pre": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
+ "swap_face": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
@@ -140,6 +141,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Buttons = buttons
}
}
+ if originTask.Properties != "" {
+ var properties dto.Properties
+ err := json.Unmarshal([]byte(originTask.Properties), &properties)
+ if err == nil {
+ midjourneyTask.Properties = &properties
+ }
+ }
return
}
@@ -260,9 +268,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if midjRequest.Prompt == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
- midjRequest.Action = "IMAGINE"
+ midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
- midjRequest.Action = "DESCRIBE"
+ midjRequest.Action = constant.MjActionDescribe
+ } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
+ midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
@@ -292,7 +302,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
}
mjId = midjRequest.TaskId
- midjRequest.Action = "INPAINT"
+ midjRequest.Action = constant.MjActionInPaint
}
originTask := model.GetByMJId(userId, mjId)
@@ -418,25 +428,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer cancel()
resp, err := service.GetHttpClient().Do(req)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "do_request_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
}
err = req.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "close_request_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
}
err = c.Request.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 5,
- Description: "close_request_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
}
var midjResponse dto.MidjourneyResponse
@@ -464,33 +465,20 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "read_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
}
err = resp.Body.Close()
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
+ }
+ if resp.StatusCode != 200 {
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
}
-
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
- if resp.StatusCode != 200 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
- }
- }
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_response_body_failed",
- }
+ return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
}
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
@@ -651,7 +639,7 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionVariation
midjRequest.Index = 1
- } else if action == "Outpaint" || strings.Contains(action, "CustomZoom") {
+ } else if action == "Outpaint" || action == "CustomZoom" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
diff --git a/router/relay-router.go b/router/relay-router.go
index 68b762b..f572d8f 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -48,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index a1ffeb6..fe6554e 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -35,6 +35,8 @@ function renderType(type) {
return 图生文;
case 'BLEAND':
return 图混合;
+ case 'SHORTEN':
+ return 缩词;
case 'REROLL':
return 重绘;
case 'INPAINT':
From d5ffaf25027feb11e0564febb4fb31b6d3abdb56 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 18:26:16 +0800
Subject: [PATCH 05/16] =?UTF-8?q?feat:=20=E6=93=8D=E4=BD=9C=E7=BB=86?=
=?UTF-8?q?=E5=88=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
Midjourney.md | 55 ++++++++++++++++++++------
constant/midjourney.go | 22 ++++++-----
controller/relay.go | 32 +--------------
relay/constant/relay_mode.go | 31 +++++++++++++++
relay/relay-mj.go | 66 +++++++++++++------------------
web/src/components/MjLogsTable.js | 6 +++
6 files changed, 122 insertions(+), 90 deletions(-)
diff --git a/Midjourney.md b/Midjourney.md
index becc9c9..d495e84 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -4,31 +4,62 @@
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
+### 模型列表
+
+### midjourney-proxy支持
+
+- mj_imagine (绘图)
+- mj_variation (变换)
+- mj_reroll (重绘)
+- mj_blend (混合)
+- mj_upscale (放大)
+- mj_describe (图生文)
+
+### 仅midjourney-proxy-plus支持
+
+- mj_zoom (比例变焦)
+- mj_shorten (提示词缩短)
+- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加)
+- mj_high_variation (强变换)
+- mj_low_variation (弱变换)
+- mj_pan (平移)
+- swap_face (换脸)
+
```json
{
- "gpt-4-gizmo-*": 0.1,
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_inpaint": 0.1,
- "mj_zoom": 0.1,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
"mj_inpaint_pre": 0,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
- "swap_face": 0.05
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05
}
```
## 渠道设置
### 对接 midjourney-proxy(plus)
-1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
+
+1.
+
+部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
+
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
+ ,模型选择midjourney,如果有换脸模型,可以选择swap_face
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
+
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
2. 地址填写上游new api的地址,例如:http://localhost:3000
3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/constant/midjourney.go b/constant/midjourney.go
index a5bccb7..5435a43 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -6,13 +6,17 @@ const (
)
const (
- MjActionImagine = "IMAGINE"
- MjActionDescribe = "DESCRIBE"
- MjActionBlend = "BLEND"
- MjActionUpscale = "UPSCALE"
- MjActionVariation = "VARIATION"
- MjActionInPaint = "INPAINT"
- MjActionInPaintPre = "INPAINT_PRE"
- MjActionZoom = "ZOOM"
- MjActionShorten = "SHORTEN"
+ MjActionImagine = "IMAGINE"
+ MjActionDescribe = "DESCRIBE"
+ MjActionBlend = "BLEND"
+ MjActionUpscale = "UPSCALE"
+ MjActionVariation = "VARIATION"
+ MjActionInPaint = "INPAINT"
+ MjActionInPaintPre = "INPAINT_PRE"
+ MjActionZoom = "ZOOM"
+ MjActionShorten = "SHORTEN"
+ MjActionHighVariation = "HIGH_VARIATION"
+ MjActionLowVariation = "LOW_VARIATION"
+ MjActionPan = "PAN"
+ SwapFace = "SWAP_FACE"
)
diff --git a/controller/relay.go b/controller/relay.go
index 7652840..d35c6a2 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -12,7 +12,6 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
- "strings"
)
func Relay(c *gin.Context) {
@@ -61,42 +60,13 @@ func Relay(c *gin.Context) {
}
func RelayMidjourney(c *gin.Context) {
- relayMode := relayconstant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") {
- // midjourney plus
- relayMode = relayconstant.RelayModeMidjourneyAction
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
- // midjourney plus
- relayMode = relayconstant.RelayModeMidjourneyModal
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/shorten") {
- // midjourney plus
- relayMode = relayconstant.RelayModeMidjourneyShorten
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
- relayMode = relayconstant.RelayModeMidjourneyImagine
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
- relayMode = relayconstant.RelayModeMidjourneyBlend
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
- relayMode = relayconstant.RelayModeMidjourneyDescribe
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
- relayMode = relayconstant.RelayModeMidjourneyNotify
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
- relayMode = relayconstant.RelayModeMidjourneyChange
- } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
- relayMode = relayconstant.RelayModeMidjourneyChange
- } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
- relayMode = relayconstant.RelayModeMidjourneyTaskFetch
- } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
- relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
- }
-
+ relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path)
var err *dto.MidjourneyResponse
switch relayMode {
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
- //case relayconstant.RelayModeMidjourneyModal:
- // err = relay.RelayMidjournneyModal(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index d8dc7ee..9f13726 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -51,3 +51,34 @@ func Path2RelayMode(path string) int {
}
return relayMode
}
+
+func Path2RelayModeMidjourney(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/mj/submit/action") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyAction
+ } else if strings.HasPrefix(path, "/mj/submit/modal") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyModal
+ } else if strings.HasPrefix(path, "/mj/submit/shorten") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyShorten
+ } else if strings.HasPrefix(path, "/mj/submit/imagine") {
+ relayMode = RelayModeMidjourneyImagine
+ } else if strings.HasPrefix(path, "/mj/submit/blend") {
+ relayMode = RelayModeMidjourneyBlend
+ } else if strings.HasPrefix(path, "/mj/submit/describe") {
+ relayMode = RelayModeMidjourneyDescribe
+ } else if strings.HasPrefix(path, "/mj/notify") {
+ relayMode = RelayModeMidjourneyNotify
+ } else if strings.HasPrefix(path, "/mj/submit/change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasPrefix(path, "/mj/submit/simple-change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasSuffix(path, "/fetch") {
+ relayMode = RelayModeMidjourneyTaskFetch
+ } else if strings.HasSuffix(path, "/list-by-condition") {
+ relayMode = RelayModeMidjourneyTaskFetchByCondition
+ }
+ return relayMode
+}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index a1f6ed4..f391f14 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -22,16 +22,20 @@ import (
)
var DefaultModelPrice = map[string]float64{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_inpaint": 0.1,
- "mj_zoom": 0.1,
- "mj_inpaint_pre": 0,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
- "swap_face": 0.05,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
@@ -151,31 +155,6 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
-func RelayMidjournneyModal(c *gin.Context) *dto.MidjourneyResponse {
- userId := c.GetInt("id")
- var midjRequest dto.MidjourneyRequest
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
- }
- originTask := model.GetByMJId(userId, midjRequest.TaskId)
- if originTask == nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
- }
-
- respBody, err := json.Marshal(midjRequest)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
- }
- return nil
-
-}
-
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -274,7 +253,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
- midjRequest.Action = "BLEND"
+ midjRequest.Action = constant.MjActionBlend
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
@@ -634,10 +613,21 @@ func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
midjRequest.Index = index
midjRequest.Action = constant.MjActionUpscale
} else if strings.Contains(action, "variation") {
- midjRequest.Action = constant.MjActionVariation
midjRequest.Index = 1
+ if action == "variation" {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionVariation
+ } else if action == "low_variation" {
+ midjRequest.Action = constant.MjActionLowVariation
+ } else if action == "high_variation" {
+ midjRequest.Action = constant.MjActionHighVariation
+ }
} else if strings.Contains(action, "pan") {
- midjRequest.Action = constant.MjActionVariation
+ midjRequest.Action = constant.MjActionPan
midjRequest.Index = 1
} else if action == "Outpaint" || action == "CustomZoom" {
midjRequest.Action = constant.MjActionZoom
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index fe6554e..603d345 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -31,6 +31,12 @@ function renderType(type) {
return 放大;
case 'VARIATION':
return 变换;
+ case 'HIGH_VARIATION':
+ return 强变换;
+ case 'LOW_VARIATION':
+ return 弱变换;
+ case 'PAN':
+ return 平移;
case 'DESCRIBE':
return 图生文;
case 'BLEAND':
From 3d10c9f090300c5653823556ca224a5d86cb86e7 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 21:19:48 +0800
Subject: [PATCH 06/16] =?UTF-8?q?feat:=20=E5=B0=86=E6=93=8D=E4=BD=9C?=
=?UTF-8?q?=E6=8B=86=E5=88=86=E6=88=90=E5=8D=95=E7=8B=AC=E7=9A=84=E6=A8=A1?=
=?UTF-8?q?=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/model-ratio.go | 32 +++--
constant/midjourney.go | 18 +++
controller/midjourney.go | 131 ------------------
controller/model.go | 18 ++-
controller/relay.go | 12 +-
dto/midjourney.go | 7 +
middleware/auth.go | 8 +-
middleware/distributor.go | 129 +++++++++++-------
middleware/utils.go | 12 +-
relay/relay-mj.go | 195 +++++++--------------------
service/midjourney.go | 135 +++++++++++++++++++
web/src/pages/Channel/EditChannel.js | 19 ++-
12 files changed, 366 insertions(+), 350 deletions(-)
create mode 100644 service/midjourney.go
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 791f733..153b748 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -94,17 +94,30 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
-var ModelPrice = map[string]float64{
- "gpt-4-gizmo-*": 0.1,
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
+var DefaultModelPrice = map[string]float64{
+ "gpt-4-gizmo-*": 0.1,
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_inpaint": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
+ "mj_inpaint_pre": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05,
}
+var ModelPrice = map[string]float64{}
+
func ModelPrice2JSONString() string {
+ if len(ModelPrice) == 0 {
+ ModelPrice = DefaultModelPrice
+ }
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
@@ -118,6 +131,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
}
func GetModelPrice(name string, printErr bool) float64 {
+ if len(ModelPrice) == 0 {
+ ModelPrice = DefaultModelPrice
+ }
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
diff --git a/constant/midjourney.go b/constant/midjourney.go
index 5435a43..92e2f23 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -11,6 +11,7 @@ const (
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
+ MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionInPaintPre = "INPAINT_PRE"
MjActionZoom = "ZOOM"
@@ -20,3 +21,20 @@ const (
MjActionPan = "PAN"
SwapFace = "SWAP_FACE"
)
+
+var MidjourneyModel2Action = map[string]string{
+ "mj_imagine": MjActionImagine,
+ "mj_describe": MjActionDescribe,
+ "mj_blend": MjActionBlend,
+ "mj_upscale": MjActionUpscale,
+ "mj_variation": MjActionVariation,
+ "mj_reroll": MjActionReRoll,
+ "mj_inpaint": MjActionInPaint,
+ "mj_inpaint_pre": MjActionInPaintPre,
+ "mj_zoom": MjActionZoom,
+ "mj_shorten": MjActionShorten,
+ "mj_high_variation": MjActionHighVariation,
+ "mj_low_variation": MjActionLowVariation,
+ "mj_pan": MjActionPan,
+ "swap_face": SwapFace,
+}
diff --git a/controller/midjourney.go b/controller/midjourney.go
index b666e91..6256471 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -18,137 +18,6 @@ import (
"time"
)
-/*func UpdateMidjourneyTask() {
- //revocer
- //imageModel := "midjourney"
- ctx := context.TODO()
- imageModel := "midjourney"
- defer func() {
- if err := recover(); err != nil {
- log.Printf("UpdateMidjourneyTask panic: %v", err)
- }
- }()
- for {
- time.Sleep(time.Duration(15) * time.Second)
- tasks := model.GetAllUnFinishTasks()
- if len(tasks) != 0 {
- common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
- for _, task := range tasks {
- common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
- midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
- if err != nil {
- common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
- task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
- task.Status = "FAILURE"
- task.Progress = "100%"
- err := task.Update()
- if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
- continue
- }
- continue
- }
- requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
- common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
-
- req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
- if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
- continue
- }
-
- // 设置超时时间
- timeout := time.Second * 5
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
-
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
-
- req.Header.Set("Content-Type", "application/json")
- //req.Header.Set("ApiKey", "Bearer midjourney-proxy")
- req.Header.Set("mj-api-secret", midjourneyChannel.Key)
- resp, err := httpClient.Do(req)
- if err != nil {
- log.Printf("UpdateMidjourneyTask error: %v", err)
- continue
- }
- responseBody, err := io.ReadAll(resp.Body)
- resp.Body.Close()
- log.Printf("responseBody: %s", string(responseBody))
- var responseItem MidjourneyDto
- // err = json.NewDecoder(resp.Body).Decode(&responseItem)
- err = json.Unmarshal(responseBody, &responseItem)
- if err != nil {
- if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
- var responseWithoutStatus MidjourneyWithoutStatus
- var responseStatus MidjourneyStatus
- err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
- err2 := json.Unmarshal(responseBody, &responseStatus)
- if err1 == nil && err2 == nil {
- jsonData, err3 := json.Marshal(responseWithoutStatus)
- if err3 != nil {
- log.Printf("UpdateMidjourneyTask error1: %v", err3)
- continue
- }
- err4 := json.Unmarshal(jsonData, &responseStatus)
- if err4 != nil {
- log.Printf("UpdateMidjourneyTask error2: %v", err4)
- continue
- }
- responseItem.Status = strconv.Itoa(responseStatus.Status)
- } else {
- log.Printf("UpdateMidjourneyTask error3: %v", err)
- continue
- }
- } else {
- log.Printf("UpdateMidjourneyTask error4: %v", err)
- continue
- }
- }
- task.Code = 1
- task.Progress = responseItem.Progress
- task.PromptEn = responseItem.PromptEn
- task.State = responseItem.State
- task.SubmitTime = responseItem.SubmitTime
- task.StartTime = responseItem.StartTime
- task.FinishTime = responseItem.FinishTime
- task.ImageUrl = responseItem.ImageUrl
- task.Status = responseItem.Status
- task.FailReason = responseItem.FailReason
- if task.Progress != "100%" && responseItem.FailReason != "" {
- common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
- task.Progress = "100%"
- err = model.CacheUpdateUserQuota(task.UserId)
- if err != nil {
- log.Println("error update user quota cache: " + err.Error())
- } else {
- modelRatio := common.GetModelRatio(imageModel)
- groupRatio := common.GetGroupRatio("default")
- ratio := modelRatio * groupRatio
- quota := int(ratio * 1 * 1000)
- if quota != 0 {
- err := model.IncreaseUserQuota(task.UserId, quota)
- if err != nil {
- log.Println("fail to increase user quota")
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
-
- err = task.Update()
- if err != nil {
- log.Printf("UpdateMidjourneyTask error5: %v", err)
- }
- log.Printf("UpdateMidjourneyTask success: %v", task)
- cancel()
- }
- }
- }
-}
-*/
-
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
diff --git a/controller/model.go b/controller/model.go
index 38c6c46..9a106aa 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -4,12 +4,13 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
- "one-api/relay/constant"
+ relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -59,8 +60,8 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- for i := 0; i < constant.APITypeDummy; i++ {
- if i == constant.APITypeAIProxyLibrary {
+ for i := 0; i < relayconstant.APITypeDummy; i++ {
+ if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -100,6 +101,17 @@ func init() {
Parent: nil,
})
}
+ for modelName, _ := range constant.MidjourneyModel2Action {
+ openAIModels = append(openAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "midjourney",
+ Permission: permission,
+ Root: modelName,
+ Parent: nil,
+ })
+ }
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
diff --git a/controller/relay.go b/controller/relay.go
index d35c6a2..fa5493a 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
}
func RelayMidjourney(c *gin.Context) {
- relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path)
+ relayMode := c.GetInt("relay_mode")
var err *dto.MidjourneyResponse
switch relayMode {
case relayconstant.RelayModeMidjourneyNotify:
@@ -73,13 +73,15 @@ func RelayMidjourney(c *gin.Context) {
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
+ statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ statusCode = http.StatusTooManyRequests
}
- c.JSON(429, gin.H{
- "error": fmt.Sprintf("%s %s", err.Description, err.Result),
- "type": "upstream_error",
- "code": err.Code,
+ c.JSON(statusCode, gin.H{
+ "description": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "type": "upstream_error",
+ "code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
diff --git a/dto/midjourney.go b/dto/midjourney.go
index d3b19d5..f81756c 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -1,5 +1,12 @@
package dto
+//type SimpleMjRequest struct {
+// Prompt string `json:"prompt"`
+// CustomId string `json:"customId"`
+// Action string `json:"action"`
+// Content string `json:"content"`
+//}
+
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
diff --git a/middleware/auth.go b/middleware/auth.go
index a8dac30..4b865c2 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
}
token, err := model.ValidateUserToken(key)
if err != nil {
- abortWithMessage(c, http.StatusUnauthorized, err.Error())
+ abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
- abortWithMessage(c, http.StatusInternalServerError, err.Error())
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
- abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
@@ -129,7 +129,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
- abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1ca43dd..c1b3ccb 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -4,7 +4,11 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/constant"
+ "one-api/dto"
"one-api/model"
+ relayconstant "one-api/relay/constant"
+ "one-api/service"
"strconv"
"strings"
@@ -23,32 +27,58 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
- abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
+ shouldSelectChannel := true
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
- // Midjourney
- if modelRequest.Model == "" {
- modelRequest.Model = "midjourney"
+ relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
+ relayMode == relayconstant.RelayModeMidjourneyNotify {
+ shouldSelectChannel = false
+ } else {
+ midjourneyRequest := dto.MidjourneyRequest{}
+ err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
+ if err != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
+ return
+ }
+ midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
+ if mjErr != nil {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
+ return
+ }
+ if midjourneyModel == "" {
+ if !success {
+ abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
+ return
+ } else {
+ // task fetch, task fetch by condition, notify
+ shouldSelectChannel = false
+ }
+ }
+ modelRequest.Model = midjourneyModel
}
+ c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -87,60 +117,61 @@ func Distribute() func(c *gin.Context) {
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
- abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
- abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
-
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
- // 如果错误,但是渠道不为空,说明是数据库一致性问题
- if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
- message = "数据库一致性已被破坏,请联系管理员"
+ if shouldSelectChannel {
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ if err != nil {
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ // 如果错误,但是渠道不为空,说明是数据库一致性问题
+ if channel != nil {
+ common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ message = "数据库一致性已被破坏,请联系管理员"
+ }
+ // 如果错误,而且渠道为空,说明是没有可用渠道
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+ return
+ }
+ if channel == nil {
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
+ return
+ }
+ c.Set("channel", channel.Type)
+ c.Set("channel_id", channel.Id)
+ c.Set("channel_name", channel.Name)
+ ban := true
+ // parse *int to bool
+ if channel.AutoBan != nil && *channel.AutoBan == 0 {
+ ban = false
+ }
+ c.Set("auto_ban", ban)
+ c.Set("model_mapping", channel.GetModelMapping())
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ c.Set("base_url", channel.GetBaseURL())
+ // TODO: api_version统一
+ switch channel.Type {
+ case common.ChannelTypeAzure:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeXunfei:
+ c.Set("api_version", channel.Other)
+ //case common.ChannelTypeAIProxyLibrary:
+ // c.Set("library_id", channel.Other)
+ case common.ChannelTypeGemini:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeAli:
+ c.Set("plugin", channel.Other)
}
- // 如果错误,而且渠道为空,说明是没有可用渠道
- abortWithMessage(c, http.StatusServiceUnavailable, message)
- return
}
- if channel == nil {
- abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
- return
- }
- }
- c.Set("channel", channel.Type)
- c.Set("channel_id", channel.Id)
- c.Set("channel_name", channel.Name)
- ban := true
- // parse *int to bool
- if channel.AutoBan != nil && *channel.AutoBan == 0 {
- ban = false
- }
- c.Set("auto_ban", ban)
- c.Set("model_mapping", channel.GetModelMapping())
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- c.Set("base_url", channel.GetBaseURL())
- // TODO: api_version统一
- switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- //case common.ChannelTypeAIProxyLibrary:
- // c.Set("library_id", channel.Other)
- case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
}
c.Next()
}
diff --git a/middleware/utils.go b/middleware/utils.go
index 021002d..43801c1 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -5,7 +5,7 @@ import (
"one-api/common"
)
-func abortWithMessage(c *gin.Context, statusCode int, message string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
common.LogError(c.Request.Context(), message)
}
+
+func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
+ c.JSON(statusCode, gin.H{
+ "description": description,
+ "type": "new_api_error",
+ "code": code,
+ })
+ c.Abort()
+ common.LogError(c.Request.Context(), description)
+}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index f391f14..6cdd9e0 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -21,23 +21,6 @@ import (
"github.com/gin-gonic/gin"
)
-var DefaultModelPrice = map[string]float64{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_inpaint": 0.1,
- "mj_zoom": 0.1,
- "mj_shorten": 0.1,
- "mj_high_variation": 0.1,
- "mj_low_variation": 0.1,
- "mj_pan": 0.1,
- "mj_inpaint_pre": 0,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
- "swap_face": 0.05,
-}
-
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
@@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
}
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
- imageModel := "midjourney"
tokenId := c.GetInt("token_id")
- channelType := c.GetInt("channel")
+ //channelType := c.GetInt("channel")
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
@@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
- mjErr := coverPlusActionToNormalAction(&midjRequest)
+ mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
@@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if midjRequest.Content == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
- params := convertSimpleChangeParams(midjRequest.Content)
+ params := service.ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
}
- mjId = params.ID
+ mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
if midjRequest.MaskBase64 == "" {
@@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ }
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
- if channelType == common.ChannelTypeMidjourneyPlus {
- // plus
- } else {
- // 普通版渠道
-
- }
+ //if channelType == common.ChannelTypeMidjourneyPlus {
+ // // plus
+ //} else {
+ // // 普通版渠道
+ //
+ //}
}
if midjRequest.Action == constant.MjActionInPaintPre {
@@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
// map model name
- modelMapping := c.GetString("model_mapping")
- isModelMapped := false
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_model_mapping_failed",
- }
- }
- if modelMap[imageModel] != "" {
- imageModel = modelMap[imageModel]
- isModelMapped = true
- }
- }
+ //modelMapping := c.GetString("model_mapping")
+ //isModelMapped := false
+ //if modelMapping != "" {
+ // modelMap := make(map[string]string)
+ // err := json.Unmarshal([]byte(modelMapping), &modelMap)
+ // if err != nil {
+ // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+ // return &dto.MidjourneyResponse{
+ // Code: 4,
+ // Description: "unmarshal_model_mapping_failed",
+ // }
+ // }
+ // if modelMap[imageModel] != "" {
+ // imageModel = modelMap[imageModel]
+ // isModelMapped = true
+ // }
+ //}
- baseURL := common.ChannelBaseURLs[channelType]
+ //baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
+ baseURL := c.GetString("base_url")
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "marshal_text_request_failed",
- }
- }
- requestBody = bytes.NewBuffer(jsonStr)
- } else {
- requestBody = c.Request.Body
- }
+ //if isModelMapped {
+ // jsonStr, err := json.Marshal(midjRequest)
+ // if err != nil {
+ // return &dto.MidjourneyResponse{
+ // Code: 4,
+ // Description: "marshal_text_request_failed",
+ // }
+ // }
+ // requestBody = bytes.NewBuffer(jsonStr)
+ //} else {
+ //}
+ requestBody = c.Request.Body
- mjAction := "mj_" + strings.ToLower(midjRequest.Action)
- modelPrice := common.GetModelPrice(mjAction, true)
+ modelName := service.CoverActionToModelName(midjRequest.Action)
+ modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
- defaultPrice, ok := DefaultModelPrice[mjAction]
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -558,85 +541,3 @@ type taskChangeParams struct {
Action string
Index int
}
-
-func convertSimpleChangeParams(content string) *taskChangeParams {
- split := strings.Split(content, " ")
- if len(split) != 2 {
- return nil
- }
-
- action := strings.ToLower(split[1])
- changeParams := &taskChangeParams{}
- changeParams.ID = split[0]
-
- if action[0] == 'u' {
- changeParams.Action = "UPSCALE"
- } else if action[0] == 'v' {
- changeParams.Action = "VARIATION"
- } else if action == "r" {
- changeParams.Action = "REROLL"
- return changeParams
- } else {
- return nil
- }
-
- index, err := strconv.Atoi(action[1:2])
- if err != nil || index < 1 || index > 4 {
- return nil
- }
- changeParams.Index = index
- return changeParams
-}
-
-func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
- // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
- customId := midjRequest.CustomId
- if customId == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
- }
- splits := strings.Split(customId, "::")
- var action string
- if splits[1] == "JOB" {
- action = splits[2]
- } else {
- action = splits[1]
- }
-
- if action == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
- }
- if strings.Contains(action, "upsample") {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionUpscale
- } else if strings.Contains(action, "variation") {
- midjRequest.Index = 1
- if action == "variation" {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionVariation
- } else if action == "low_variation" {
- midjRequest.Action = constant.MjActionLowVariation
- } else if action == "high_variation" {
- midjRequest.Action = constant.MjActionHighVariation
- }
- } else if strings.Contains(action, "pan") {
- midjRequest.Action = constant.MjActionPan
- midjRequest.Index = 1
- } else if action == "Outpaint" || action == "CustomZoom" {
- midjRequest.Action = constant.MjActionZoom
- midjRequest.Index = 1
- } else if action == "Inpaint" {
- midjRequest.Action = constant.MjActionInPaintPre
- midjRequest.Index = 1
- } else {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
- }
- return nil
-}
diff --git a/service/midjourney.go b/service/midjourney.go
new file mode 100644
index 0000000..06730c8
--- /dev/null
+++ b/service/midjourney.go
@@ -0,0 +1,135 @@
+package service
+
+import (
+ "one-api/constant"
+ "one-api/dto"
+ relayconstant "one-api/relay/constant"
+ "strconv"
+ "strings"
+)
+
+func CoverActionToModelName(mjAction string) string {
+ modelName := "mj_" + strings.ToLower(mjAction)
+ return modelName
+}
+
+func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
+ action := ""
+ if relayMode == relayconstant.RelayModeMidjourneyAction {
+ // plus request
+ err := CoverPlusActionToNormalAction(midjRequest)
+ if err != nil {
+ return "", err, false
+ }
+ action = midjRequest.Action
+ } else {
+ switch relayMode {
+ case relayconstant.RelayModeMidjourneyImagine:
+ action = constant.MjActionImagine
+ case relayconstant.RelayModeMidjourneyDescribe:
+ action = constant.MjActionDescribe
+ case relayconstant.RelayModeMidjourneyBlend:
+ action = constant.MjActionBlend
+ case relayconstant.RelayModeMidjourneyShorten:
+ action = constant.MjActionShorten
+ case relayconstant.RelayModeMidjourneyChange:
+ action = midjRequest.Action
+ case relayconstant.RelayModeMidjourneyModal:
+ action = constant.MjActionInPaint
+ case relayconstant.RelayModeMidjourneySimpleChange:
+ params := ConvertSimpleChangeParams(midjRequest.Content)
+ if params == nil {
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
+ }
+ action = params.Action
+ case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
+ return "", nil, true
+ default:
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
+ }
+ }
+ modelName := CoverActionToModelName(action)
+ return modelName, nil, true
+}
+
+func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+ // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+ customId := midjRequest.CustomId
+ if customId == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+ }
+ splits := strings.Split(customId, "::")
+ var action string
+ if splits[1] == "JOB" {
+ action = splits[2]
+ } else {
+ action = splits[1]
+ }
+
+ if action == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ if strings.Contains(action, "upsample") {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionUpscale
+ } else if strings.Contains(action, "variation") {
+ midjRequest.Index = 1
+ if action == "variation" {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionVariation
+ } else if action == "low_variation" {
+ midjRequest.Action = constant.MjActionLowVariation
+ } else if action == "high_variation" {
+ midjRequest.Action = constant.MjActionHighVariation
+ }
+ } else if strings.Contains(action, "pan") {
+ midjRequest.Action = constant.MjActionPan
+ midjRequest.Index = 1
+ } else if action == "Outpaint" || action == "CustomZoom" {
+ midjRequest.Action = constant.MjActionZoom
+ midjRequest.Index = 1
+ } else if action == "Inpaint" {
+ midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Index = 1
+ } else {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ return nil
+}
+
+func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
+ split := strings.Split(content, " ")
+ if len(split) != 2 {
+ return nil
+ }
+
+ action := strings.ToLower(split[1])
+ changeParams := &dto.MidjourneyRequest{}
+ changeParams.TaskId = split[0]
+
+ if action[0] == 'u' {
+ changeParams.Action = "UPSCALE"
+ } else if action[0] == 'v' {
+ changeParams.Action = "VARIATION"
+ } else if action == "r" {
+ changeParams.Action = "REROLL"
+ return changeParams
+ } else {
+ return nil
+ }
+
+ index, err := strconv.Atoi(action[1:2])
+ if err != nil || index < 1 || index > 4 {
+ return nil
+ }
+ changeParams.Index = index
+ return changeParams
+}
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index ee79368..225ce3f 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -96,10 +96,25 @@ const EditChannel = (props) => {
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
break;
case 2:
- localModels = ['midjourney'];
+ localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
break;
case 5:
- localModels = ['midjourney'];
+ localModels = [
+ 'swap_face',
+ 'mj_imagine',
+ 'mj_variation',
+ 'mj_reroll',
+ 'mj_blend',
+ 'mj_upscale',
+ 'mj_describe',
+ 'mj_zoom',
+ 'mj_shorten',
+ 'mj_inpaint_pre',
+ 'mj_inpaint_pre',
+ 'mj_high_variation',
+ 'mj_low_variation',
+ 'mj_pan',
+ ];
break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
From d3399d68f6aae513709bc15646469c1ae652c93c Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 22:24:02 +0800
Subject: [PATCH 07/16] fix: fix typo
---
web/src/components/MjLogsTable.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 603d345..88da1cd 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -39,7 +39,7 @@ function renderType(type) {
return 平移;
case 'DESCRIBE':
return 图生文;
- case 'BLEAND':
+ case 'BLEND':
return 图混合;
case 'SHORTEN':
return 缩词;
From 9b2e5c2978721ba6273a5c7b3b33ba9a2a1f3503 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 22:30:10 +0800
Subject: [PATCH 08/16] refactor: remove consumeQuota
---
relay/relay-image.go | 77 ++++++++++++++++++++------------------------
1 file changed, 35 insertions(+), 42 deletions(-)
diff --git a/relay/relay-image.go b/relay/relay-image.go
index 3065496..aabe4ba 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
- consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
startTime := time.Now()
var imageRequest dto.ImageRequest
- if consumeQuota {
- err := common.UnmarshalBodyReusable(c, &imageRequest)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
- }
+ err := common.UnmarshalBodyReusable(c, &imageRequest)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.Model == "" {
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
- if consumeQuota && userQuota-quota < 0 {
+ if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
var textResponse dto.ImageResponse
defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
- if consumeQuota {
- if resp.StatusCode != http.StatusOK {
- return
- }
- err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
- if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
- }
+ if resp.StatusCode != http.StatusOK {
+ return
+ }
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
- if consumeQuota {
- responseBody, err := io.ReadAll(resp.Body)
+ responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
- err = json.Unmarshal(responseBody, &textResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
+ err = resp.Body.Close()
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ }
+ err = json.Unmarshal(responseBody, &textResponse)
+ if err != nil {
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ }
+
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
From 44361d75e88d066a8554dfeff384606d19df2d29 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 13 Mar 2024 23:17:12 +0800
Subject: [PATCH 09/16] fix: "Inpaint" code error
---
relay/relay-mj.go | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 6cdd9e0..7342717 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -494,8 +494,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//修改返回值
- newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
- responseBody = []byte(newBody)
+ if midjRequest.Action != constant.MjActionInPaintPre {
+ newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
+ responseBody = []byte(newBody)
+ }
}
err = midjourneyTask.Insert()
From a77fbc0fa214184391224c12340da09d63acce7f Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 00:43:32 +0800
Subject: [PATCH 10/16] fix: reroll action error
---
model/ability.go | 7 ++++++-
service/midjourney.go | 7 +++++--
web/src/pages/Channel/EditChannel.js | 2 +-
3 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/model/ability.go b/model/ability.go
index 7a81cc2..b79978d 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -147,7 +147,12 @@ func FixAbility() (int, error) {
return 0, err
}
var channels []Channel
- err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+
+ if len(abilityChannelIds) == 0 {
+ err = DB.Find(&channels).Error
+ } else {
+ err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+ }
if err != nil {
return 0, err
}
diff --git a/service/midjourney.go b/service/midjourney.go
index 06730c8..c04c4d3 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -45,7 +45,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
return "", nil, true
default:
- return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
}
}
modelName := CoverActionToModelName(action)
@@ -93,6 +93,9 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionPan
midjRequest.Index = 1
+ } else if strings.Contains(action, "reroll") {
+ midjRequest.Action = constant.MjActionReRoll
+ midjRequest.Index = 1
} else if action == "Outpaint" || action == "CustomZoom" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
@@ -100,7 +103,7 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
midjRequest.Action = constant.MjActionInPaintPre
midjRequest.Index = 1
} else {
- return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
}
return nil
}
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 225ce3f..757b56c 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -110,7 +110,7 @@ const EditChannel = (props) => {
'mj_zoom',
'mj_shorten',
'mj_inpaint_pre',
- 'mj_inpaint_pre',
+ 'mj_inpaint',
'mj_high_variation',
'mj_low_variation',
'mj_pan',
From 614220a0fb191c51e1f1bda5e9497536fe952e23 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 15:16:36 +0800
Subject: [PATCH 11/16] =?UTF-8?q?feat:=20=E8=B6=85=E8=BF=87=E4=B8=80?=
=?UTF-8?q?=E5=B0=8F=E6=97=B6=E7=9A=84=E4=BB=BB=E5=8A=A1=E8=87=AA=E5=8A=A8?=
=?UTF-8?q?=E5=A4=B1=E8=B4=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
controller/midjourney.go | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 6256471..41db4bf 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -118,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
+
+ useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
+ // 如果时间超过一小时,且进度不是100%,则认为任务失败
+ if useTime > 3600000 && task.Progress != "100%" {
+ responseItem.FailReason = "上游任务超时(超过1小时)"
+ responseItem.Status = "FAILURE"
+ }
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
-
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@@ -140,6 +146,7 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
+
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
From d704902b70d34dafcadc1f54cca75bb2c986f7ef Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 16:42:37 +0800
Subject: [PATCH 12/16] =?UTF-8?q?feat:=20=E5=85=BC=E5=AE=B9=E8=87=AA?=
=?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=8F=98=E7=84=A6=EF=BC=8C=E5=AE=8C=E5=96=84?=
=?UTF-8?q?modal=E6=93=8D=E4=BD=9C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/model-ratio.go | 5 +-
constant/midjourney.go | 6 +-
controller/relay.go | 2 +
dto/midjourney.go | 5 +
middleware/distributor.go | 3 +-
relay/constant/relay_mode.go | 3 +
relay/relay-mj.go | 139 ++++++++-------------------
router/relay-router.go | 1 +
service/error.go | 7 ++
service/midjourney.go | 73 +++++++++++++-
web/src/components/MjLogsTable.js | 10 +-
web/src/pages/Channel/EditChannel.js | 3 +-
12 files changed, 147 insertions(+), 110 deletions(-)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 153b748..3231d95 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -100,13 +100,14 @@ var DefaultModelPrice = map[string]float64{
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
- "mj_inpaint": 0.1,
+ "mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
- "mj_inpaint_pre": 0,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
diff --git a/constant/midjourney.go b/constant/midjourney.go
index 92e2f23..f4ae2e4 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -13,8 +13,9 @@ const (
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
- MjActionInPaintPre = "INPAINT_PRE"
+ MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
+ MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
@@ -29,9 +30,10 @@ var MidjourneyModel2Action = map[string]string{
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
+ "mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
- "mj_inpaint_pre": MjActionInPaintPre,
"mj_zoom": MjActionZoom,
+ "mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
diff --git a/controller/relay.go b/controller/relay.go
index fa5493a..e31679d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -67,6 +67,8 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
+ case relayconstant.RelayModeMidjourneyTaskImageSeed:
+ err = relay.RelayMidjourneyTaskImageSeed(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index f81756c..d3c3583 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -28,6 +28,11 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
+type MidjourneyResponseWithStatusCode struct {
+ StatusCode int `json:"statusCode"`
+ Response MidjourneyResponse
+}
+
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
diff --git a/middleware/distributor.go b/middleware/distributor.go
index c1b3ccb..ed457a3 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -48,7 +48,8 @@ func Distribute() func(c *gin.Context) {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
- relayMode == relayconstant.RelayModeMidjourneyNotify {
+ relayMode == relayconstant.RelayModeMidjourneyNotify ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 9f13726..197efdc 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -17,6 +17,7 @@ const (
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
+ RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
@@ -77,6 +78,8 @@ func Path2RelayModeMidjourney(path string) int {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
+ } else if strings.HasSuffix(path, "/image-seed") {
+ relayMode = RelayModeMidjourneyTaskImageSeed
} else if strings.HasSuffix(path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 7342717..8eebaeb 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+ //taskId := c.Param("id")
+ //userId := c.GetInt("id")
+ //originTask := model.GetByMJId(userId, taskId)
+ //if originTask == nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ //}
+ //channel, err := model.GetChannelById(originTask.ChannelId, false)
+ //if err != nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ //}
+ //if channel.Status != common.ChannelStatusEnabled {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ //}
+ //c.Set("channel_id", originTask.ChannelId)
+ //requestURL := c.Request.URL.String()
+ //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+ //if err != nil {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
+ //}
+ log.Println("RelayMidjourneyTaskImageSeed")
+ return nil
+}
+
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
- if midjRequest.MaskBase64 == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
- }
+ //if midjRequest.MaskBase64 == "" {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ //}
mjId = midjRequest.TaskId
- midjRequest.Action = constant.MjActionInPaint
+ midjRequest.Action = constant.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
@@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//}
}
- if midjRequest.Action == constant.MjActionInPaintPre {
+ if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
consumeQuota = false
}
- // map model name
- //modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- //if modelMapping != "" {
- // modelMap := make(map[string]string)
- // err := json.Unmarshal([]byte(modelMapping), &modelMap)
- // if err != nil {
- // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- // return &dto.MidjourneyResponse{
- // Code: 4,
- // Description: "unmarshal_model_mapping_failed",
- // }
- // }
- // if modelMap[imageModel] != "" {
- // imageModel = modelMap[imageModel]
- // isModelMapped = true
- // }
- //}
-
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- var requestBody io.Reader
- //if isModelMapped {
- // jsonStr, err := json.Marshal(midjRequest)
- // if err != nil {
- // return &dto.MidjourneyResponse{
- // Code: 4,
- // Description: "marshal_text_request_failed",
- // }
- // }
- // requestBody = bytes.NewBuffer(jsonStr)
- //} else {
- //}
- requestBody = c.Request.Body
-
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
@@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "create_request_failed",
- }
+ return &midjResponseWithStatus.Response
}
- //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
- timeout := time.Second * 30
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
- // print request header
- //log.Printf("request header: %s", req.Header)
- //log.Printf("request body: %s", midjRequest.Prompt)
-
- defer cancel()
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
- }
-
- err = req.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
- }
- err = c.Request.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
- }
- var midjResponse dto.MidjourneyResponse
+ midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
if consumeQuota {
@@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
- }
- err = resp.Body.Close()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
- }
- if resp.StatusCode != 200 {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
- }
- err = json.Unmarshal(responseBody, &midjResponse)
- log.Printf("responseBody: %s", string(responseBody))
- log.Printf("midjResponse: %v", midjResponse)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
- }
-
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//修改返回值
- if midjRequest.Action != constant.MjActionInPaintPre {
+ if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
@@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
+ _, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
- err = resp.Body.Close()
+ err = bodyReader.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
diff --git a/router/relay-router.go b/router/relay-router.go
index f572d8f..3c6910a 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -57,6 +57,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
}
//relayMjRouter.Use()
diff --git a/service/error.go b/service/error.go
index 91c78c8..424be5d 100644
--- a/service/error.go
+++ b/service/error.go
@@ -18,6 +18,13 @@ func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
}
}
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: *MidjourneyErrorWrapper(code, desc),
+ }
+}
+
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
diff --git a/service/midjourney.go b/service/midjourney.go
index c04c4d3..17e54e1 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -1,11 +1,18 @@
package service
import (
+ "context"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "log"
+ "net/http"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"strconv"
"strings"
+ "time"
)
func CoverActionToModelName(mjAction string) string {
@@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
case relayconstant.RelayModeMidjourneyChange:
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
- action = constant.MjActionInPaint
+ action = constant.MjActionModal
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "reroll") {
midjRequest.Action = constant.MjActionReRoll
midjRequest.Index = 1
- } else if action == "Outpaint" || action == "CustomZoom" {
+ } else if action == "Outpaint" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
+ } else if action == "CustomZoom" {
+ midjRequest.Action = constant.MjActionCustomZoom
+ midjRequest.Index = 1
} else if action == "Inpaint" {
- midjRequest.Action = constant.MjActionInPaintPre
+ midjRequest.Action = constant.MjActionInPaint
midjRequest.Index = 1
} else {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
@@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
changeParams.Index = index
return changeParams
}
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+ var nullBytes []byte
+ var requestBody io.Reader
+ requestBody = c.Request.Body
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // 使用带有超时的 context 创建新的请求
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+ defer cancel()
+ resp, err := GetHttpClient().Do(req)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ statusCode := resp.StatusCode
+ //if statusCode != 200 {
+ // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+ //}
+ err = req.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ var midjResponse dto.MidjourneyResponse
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
+ }
+
+ err = json.Unmarshal(responseBody, &midjResponse)
+ log.Printf("responseBody: %s", string(responseBody))
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+ }
+ //log.Printf("midjResponse: %v", midjResponse)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: midjResponse,
+ }, responseBody, nil
+}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 88da1cd..4843b2f 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -46,11 +46,13 @@ function renderType(type) {
case 'REROLL':
return 重绘;
case 'INPAINT':
- return 局部重绘;
+ return 局部重绘-提交;
case 'ZOOM':
return 变焦;
- case 'INPAINT_PRE':
- return 局部重绘-预处理;
+ case 'CUSTOM_ZOOM':
+ return 自定义变焦-提交;
+ case 'MODAL':
+ return 窗口处理;
default:
return 未知;
}
@@ -62,7 +64,7 @@ function renderCode(code) {
case 1:
return 已提交;
case 21:
- return 排队中;
+ return 等待中;
case 22:
return 重复提交;
default:
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 757b56c..ccc18aa 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -109,8 +109,9 @@ const EditChannel = (props) => {
'mj_describe',
'mj_zoom',
'mj_shorten',
- 'mj_inpaint_pre',
+ 'mj_modal',
'mj_inpaint',
+ 'mj_custom_zoom',
'mj_high_variation',
'mj_low_variation',
'mj_pan',
From bc5a54df59ea464186390b95aa56cfbd5c605d55 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 16:59:46 +0800
Subject: [PATCH 13/16] feat: support image-seed (close #86)
---
relay/relay-mj.go | 56 ++++++++++++++++++++++++++++-------------------
1 file changed, 34 insertions(+), 22 deletions(-)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 8eebaeb..6185d5b 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
- //taskId := c.Param("id")
- //userId := c.GetInt("id")
- //originTask := model.GetByMJId(userId, taskId)
- //if originTask == nil {
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
- //}
- //channel, err := model.GetChannelById(originTask.ChannelId, false)
- //if err != nil {
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
- //}
- //if channel.Status != common.ChannelStatusEnabled {
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
- //}
- //c.Set("channel_id", originTask.ChannelId)
- //requestURL := c.Request.URL.String()
- //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
- //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
- //if err != nil {
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
- //}
- log.Println("RelayMidjourneyTaskImageSeed")
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+ originTask := model.GetByMJId(userId, taskId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ }
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ }
+ c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+ requestURL := c.Request.URL.String()
+ fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
+ if err != nil {
+ return &midjResponseWithStatus.Response
+ }
+ midjResponse := &midjResponseWithStatus.Response
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
return nil
}
@@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
- channel, err := model.GetChannelById(originTask.ChannelId, false)
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
@@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
From 9b5353a81a4f993f519641615c960d712b546d1c Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 18:08:12 +0800
Subject: [PATCH 14/16] feat: support InsightFace (close #60)
---
Midjourney.md | 10 ++-
constant/midjourney.go | 4 +-
controller/relay.go | 2 +
dto/midjourney.go | 5 ++
relay/constant/relay_mode.go | 4 +
relay/relay-mj.go | 129 +++++++++++++++++++++++++++++-
router/relay-router.go | 1 +
service/midjourney.go | 27 ++++++-
web/src/components/MjLogsTable.js | 4 +
9 files changed, 173 insertions(+), 13 deletions(-)
diff --git a/Midjourney.md b/Midjourney.md
index d495e84..5733a11 100644
--- a/Midjourney.md
+++ b/Midjourney.md
@@ -19,8 +19,9 @@
- mj_zoom (比例变焦)
- mj_shorten (提示词缩短)
-- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加)
-- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
- mj_high_variation (强变换)
- mj_low_variation (弱变换)
- mj_pan (平移)
@@ -32,13 +33,14 @@
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
- "mj_inpaint": 0.1,
+ "mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
- "mj_inpaint_pre": 0,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05
diff --git a/constant/midjourney.go b/constant/midjourney.go
index f4ae2e4..3d321ca 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -20,7 +20,7 @@ const (
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
- SwapFace = "SWAP_FACE"
+ MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
@@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
- "swap_face": SwapFace,
+ "swap_face": MjActionSwapFace,
}
diff --git a/controller/relay.go b/controller/relay.go
index e31679d..9f866b8 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
+ case relayconstant.RelayModeSwapFace:
+ err = relay.RelaySwapFace(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index d3c3583..c675f7e 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -7,6 +7,11 @@ package dto
// Content string `json:"content"`
//}
+type SwapFaceRequest struct {
+ SourceBase64 string `json:"sourceBase64"`
+ TargetBase64 string `json:"targetBase64"`
+}
+
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 197efdc..1790c57 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -25,6 +25,7 @@ const (
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
+ RelayModeSwapFace
)
func Path2RelayMode(path string) int {
@@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = RelayModeMidjourneyShorten
+ } else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+ // midjourney plus
+ relayMode = RelayModeSwapFace
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(path, "/mj/submit/blend") {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 6185d5b..01ae0c8 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
+ tokenId := c.GetInt("token_id")
+ userId := c.GetInt("id")
+ group := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ var swapFaceRequest dto.SwapFaceRequest
+ err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+ }
+ modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+ modelPrice := common.GetModelPrice(modelName, true)
+ // 如果没有配置价格,则使用默认价格
+ if modelPrice == -1 {
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
+ if !ok {
+ modelPrice = 0.1
+ } else {
+ modelPrice = defaultPrice
+ }
+ }
+ groupRatio := common.GetGroupRatio(group)
+ ratio := modelPrice * groupRatio
+ userQuota, err := model.CacheGetUserQuota(userId)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: err.Error(),
+ }
+ }
+ quota := int(ratio * common.QuotaPerUnit)
+
+ if userQuota-quota < 0 {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+ requestURL := c.Request.URL.String()
+ baseURL := c.GetString("base_url")
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
+ if err != nil {
+ return &mjResp.Response
+ }
+ defer func(ctx context.Context) {
+ if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(userId)
+ if err != nil {
+ common.SysError("error update user quota cache: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ channelId := c.GetInt("channel_id")
+ model.UpdateChannelUsedQuota(channelId, quota)
+ }
+ }
+ }(c.Request.Context())
+ midjResponse := &mjResp.Response
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: constant.MjActionSwapFace,
+ MjId: midjResponse.Result,
+ Prompt: "swap_face",
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: startTime,
+ StartTime: time.Now().UnixNano() / int64(time.Millisecond),
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ ChannelId: c.GetInt("channel_id"),
+ Quota: quota,
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+ }
+ c.Writer.WriteHeader(mjResp.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+}
+
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
taskId := c.Param("id")
userId := c.GetInt("id")
@@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
requestURL := c.Request.URL.String()
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
- midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
+ //defer func(ctx context.Context) {
+ // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+ // if err != nil {
+ // common.SysError("error consuming token remain quota: " + err.Error())
+ // }
+ // err = model.CacheUpdateUserQuota(userId)
+ // if err != nil {
+ // common.SysError("error update user quota cache: " + err.Error())
+ // }
+ // if quota != 0 {
+ // tokenName := c.GetString("token_name")
+ // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+ // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+ // model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+ // channelId := c.GetInt("channel_id")
+ // model.UpdateChannelUsedQuota(channelId, quota)
+ // }
+ //}(c.Request.Context())
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
@@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
- if consumeQuota {
+ if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
diff --git a/router/relay-router.go b/router/relay-router.go
index 3c6910a..4addee0 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+ relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
}
//relayMjRouter.Use()
}
diff --git a/service/midjourney.go b/service/midjourney.go
index 17e54e1..7c47cd6 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -17,6 +17,9 @@ import (
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
+ if mjAction == constant.MjActionSwapFace {
+ modelName = "swap_face"
+ }
return modelName
}
@@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionModal
+ case relayconstant.RelayModeSwapFace:
+ action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
return changeParams
}
-func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
- var requestBody io.Reader
- requestBody = c.Request.Body
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ //var requestBody io.Reader
+ //requestBody = c.Request.Body
+ // read request body to json, delete accountFilter and notifyHook
+ var mapResult map[string]interface{}
+ err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ delete(mapResult, "accountFilter")
+ delete(mapResult, "notifyHook")
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ // make new request with mapResult
+ reqBody, err := json.Marshal(mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js
index 4843b2f..90c55f1 100644
--- a/web/src/components/MjLogsTable.js
+++ b/web/src/components/MjLogsTable.js
@@ -53,6 +53,8 @@ function renderType(type) {
return 自定义变焦-提交;
case 'MODAL':
return 窗口处理;
+ case 'SWAP_FACE':
+ return 换脸;
default:
return 未知;
}
@@ -67,6 +69,8 @@ function renderCode(code) {
return 等待中;
case 22:
return 重复提交;
+ case 0:
+ return 未提交;
default:
return 未知;
}
From 2786a6b53931230d128502a6d2b01d18c798e13d Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 18:10:09 +0800
Subject: [PATCH 15/16] Update README.md
---
README.md | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index f1d18da..ce3f27f 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
此分叉版本的主要变更如下:
1. 全新的UI界面(部分界面还待更新)
-2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
+2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
@@ -26,6 +26,11 @@
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
+ [x] /task/list-by-condition
+ + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
+ + [x] /mj/submit/modal
+ + [x] /mj/submit/shorten
+ + [x] /mj/task/{id}/image-seed
+ + [x] /mj/insight-face/swap (InsightFace)
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
@@ -49,6 +54,7 @@
2. 智谱glm-4v,glm-4v识图
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
+5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
From 84e0544604cc3a23bba80bd834b5314710fab1ee Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Thu, 14 Mar 2024 18:19:22 +0800
Subject: [PATCH 16/16] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E8=B6=85?=
=?UTF-8?q?=E6=97=B6=E6=97=B6=E9=97=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
relay/relay-mj.go | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 01ae0c8..35353b4 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -183,7 +183,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
requestURL := c.Request.URL.String()
baseURL := c.GetString("base_url")
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
if err != nil {
return &mjResp.Response
}
@@ -213,7 +213,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
Code: midjResponse.Code,
Action: constant.MjActionSwapFace,
MjId: midjResponse.Result,
- Prompt: "swap_face",
+ Prompt: "InsightFace",
PromptEn: "",
Description: midjResponse.Description,
State: "",
@@ -495,7 +495,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
- midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}