From 976da45bcef1139bf596b692804a8bc2128f7850 Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 20 Feb 2024 11:23:55 +0800 Subject: [PATCH] feat: allow user config third-party platform openai and mj api key --- api/core/types/config.go | 1 + api/service/mj/plus/client.go | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/api/core/types/config.go b/api/core/types/config.go index 4fcb581d..bb71dd8d 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -65,6 +65,7 @@ type StableDiffusionConfig struct { type MidJourneyPlusConfig struct { Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 ApiURL string // api 地址 + Mode string // 绘画模式,可选值:fast/turbo/relax CdnURL string // CDN 加速地址 ApiKey string NotifyURL string // 任务进度更新回调地址 diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go index b2035929..6dc03af0 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus/client.go @@ -7,11 +7,10 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/imroc/req/v3" "io" "github.com/gin-gonic/gin" - - "github.com/imroc/req/v3" ) var logger = logger2.GetLogger() @@ -29,6 +28,9 @@ func NewClient(config types.MidJourneyPlusConfig) *Client { } else { apiURL = config.ApiURL } + if config.Mode == "" { + config.Mode = "fast" + } return &Client{Config: config, apiURL: apiURL} } @@ -62,7 +64,7 @@ type ErrRes struct { } func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { - apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) body := ImageReq{ BotType: "MID_JOURNEY", Prompt: task.Prompt, @@ -101,7 +103,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { // Blend 融图 func (c *Client) Blend(task types.MjTask) (ImageRes, error) { - apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/blend", c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) body := ImageReq{ BotType: "MID_JOURNEY", Dimensions: "SQUARE", @@ -141,7 +143,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) { // SwapFace 换脸 func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { - apiURL := fmt.Sprintf("%s/mj-fast/mj/insight-face/swap", c.apiURL) + apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) // 生成图片 Base64 编码 if len(task.ImgArr) != 2 { return ImageRes{}, errors.New("参数错误,必须上传2张图片")