check if the api url in whitelist for mj plus client

This commit is contained in:
RockYang 2024-05-22 11:47:04 +08:00
parent 5742b40aee
commit 6944a32ff3
4 changed files with 43 additions and 15 deletions

View File

@ -146,7 +146,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}).

View File

@ -12,6 +12,7 @@ import (
"errors"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
@ -22,20 +23,30 @@ import (
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
Config types.MjPlusConfig
apiURL string
client *req.Client
licenseService *service.LicenseService
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
}
}
func (c *PlusClient) preCheck() error {
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
@ -79,6 +90,10 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
@ -118,6 +133,10 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
@ -167,6 +186,10 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
@ -194,6 +217,10 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,

View File

@ -62,13 +62,8 @@ func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfig
if config.Enabled == false {
continue
}
err := p.licenseService.IsValidApiURL(config.ApiURL)
if err != nil {
logger.Errorf("创建 MJ-PLUS 服务失败:%v", err)
continue
}
cli := NewPlusClient(config)
cli := NewPlusClient(config, p.licenseService)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {

View File

@ -108,7 +108,13 @@ func (s *Service) Run() {
}
if err != nil || (res.Code != 1 && res.Code != 22) {
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = -1
job.ErrMsg = errMsg