check if the api url in whitelist for mj plus client

This commit is contained in:
RockYang 2024-05-22 11:47:04 +08:00
parent 5742b40aee
commit 6944a32ff3
4 changed files with 43 additions and 15 deletions

View File

@ -146,7 +146,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
Model: "dall-e-3", Model: "dall-e-3",
Prompt: prompt, Prompt: prompt,
N: 1, N: 1,
Size: "1024x1024", Size: task.Size,
Style: task.Style, Style: task.Style,
Quality: task.Quality, Quality: task.Quality,
}). }).

View File

@ -12,6 +12,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/utils" "geekai/utils"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"io" "io"
@ -22,20 +23,30 @@ import (
// PlusClient MidJourney Plus ProxyClient // PlusClient MidJourney Plus ProxyClient
type PlusClient struct { type PlusClient struct {
Config types.MjPlusConfig Config types.MjPlusConfig
apiURL string apiURL string
client *req.Client client *req.Client
licenseService *service.LicenseService
} }
func NewPlusClient(config types.MjPlusConfig) *PlusClient { func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
return &PlusClient{ return &PlusClient{
Config: config, Config: config,
apiURL: config.ApiURL, apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
} }
} }
func (c *PlusClient) preCheck() error {
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" { if task.NegPrompt != "" {
@ -79,6 +90,10 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
// Blend 融图 // Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL) logger.Info("API URL: ", apiURL)
body := ImageReq{ body := ImageReq{
@ -118,6 +133,10 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
// SwapFace 换脸 // SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码 // 生成图片 Base64 编码
if len(task.ImgArr) != 2 { if len(task.ImgArr) != 2 {
@ -167,6 +186,10 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
// Upscale 放大指定的图片 // Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{ body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId, "taskId": task.MessageId,
@ -194,6 +217,10 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{ body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId, "taskId": task.MessageId,

View File

@ -62,13 +62,8 @@ func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfig
if config.Enabled == false { if config.Enabled == false {
continue continue
} }
err := p.licenseService.IsValidApiURL(config.ApiURL)
if err != nil {
logger.Errorf("创建 MJ-PLUS 服务失败:%v", err)
continue
}
cli := NewPlusClient(config) cli := NewPlusClient(config, p.licenseService)
name := fmt.Sprintf("mj-plus-service-%d", k) name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() { go func() {

View File

@ -108,7 +108,13 @@ func (s *Service) Run() {
} }
if err != nil || (res.Code != 1 && res.Code != 22) { if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description) var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg) logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1 job.Progress = -1
job.ErrMsg = errMsg job.ErrMsg = errMsg