mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-21 18:56:38 +08:00
check if the api url in whitelist for mj plus client
This commit is contained in:
parent
5742b40aee
commit
6944a32ff3
@ -146,7 +146,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
|||||||
Model: "dall-e-3",
|
Model: "dall-e-3",
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
N: 1,
|
N: 1,
|
||||||
Size: "1024x1024",
|
Size: task.Size,
|
||||||
Style: task.Style,
|
Style: task.Style,
|
||||||
Quality: task.Quality,
|
Quality: task.Quality,
|
||||||
}).
|
}).
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geekai/core/types"
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
"geekai/utils"
|
"geekai/utils"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"io"
|
"io"
|
||||||
@ -22,20 +23,30 @@ import (
|
|||||||
|
|
||||||
// PlusClient MidJourney Plus ProxyClient
|
// PlusClient MidJourney Plus ProxyClient
|
||||||
type PlusClient struct {
|
type PlusClient struct {
|
||||||
Config types.MjPlusConfig
|
Config types.MjPlusConfig
|
||||||
apiURL string
|
apiURL string
|
||||||
client *req.Client
|
client *req.Client
|
||||||
|
licenseService *service.LicenseService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
|
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
|
||||||
return &PlusClient{
|
return &PlusClient{
|
||||||
Config: config,
|
Config: config,
|
||||||
apiURL: config.ApiURL,
|
apiURL: config.ApiURL,
|
||||||
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
||||||
|
licenseService: licenseService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *PlusClient) preCheck() error {
|
||||||
|
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
||||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||||
if task.NegPrompt != "" {
|
if task.NegPrompt != "" {
|
||||||
@ -79,6 +90,10 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
|
|
||||||
// Blend 融图
|
// Blend 融图
|
||||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
||||||
logger.Info("API URL: ", apiURL)
|
logger.Info("API URL: ", apiURL)
|
||||||
body := ImageReq{
|
body := ImageReq{
|
||||||
@ -118,6 +133,10 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
|||||||
|
|
||||||
// SwapFace 换脸
|
// SwapFace 换脸
|
||||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
if len(task.ImgArr) != 2 {
|
if len(task.ImgArr) != 2 {
|
||||||
@ -167,6 +186,10 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
|
|
||||||
// Upscale 放大指定的图片
|
// Upscale 放大指定的图片
|
||||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
@ -194,6 +217,10 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
|||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
if err := c.preCheck(); err != nil {
|
||||||
|
return ImageRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
|
@ -62,13 +62,8 @@ func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfig
|
|||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err := p.licenseService.IsValidApiURL(config.ApiURL)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("创建 MJ-PLUS 服务失败:%v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cli := NewPlusClient(config)
|
cli := NewPlusClient(config, p.licenseService)
|
||||||
name := fmt.Sprintf("mj-plus-service-%d", k)
|
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||||
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -108,7 +108,13 @@ func (s *Service) Run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
var errMsg string
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
} else {
|
||||||
|
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Error("绘画任务执行失败:", errMsg)
|
logger.Error("绘画任务执行失败:", errMsg)
|
||||||
job.Progress = -1
|
job.Progress = -1
|
||||||
job.ErrMsg = errMsg
|
job.ErrMsg = errMsg
|
||||||
|
Loading…
Reference in New Issue
Block a user