diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index bab51923..3dfbe6c0 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -146,7 +146,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { Model: "dall-e-3", Prompt: prompt, N: 1, - Size: "1024x1024", + Size: task.Size, Style: task.Style, Quality: task.Quality, }). diff --git a/api/service/mj/plus_client.go b/api/service/mj/plus_client.go index 7f85fd61..beb8943c 100644 --- a/api/service/mj/plus_client.go +++ b/api/service/mj/plus_client.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "geekai/core/types" + "geekai/service" "geekai/utils" "github.com/imroc/req/v3" "io" @@ -22,20 +23,30 @@ import ( // PlusClient MidJourney Plus ProxyClient type PlusClient struct { - Config types.MjPlusConfig - apiURL string - client *req.Client + Config types.MjPlusConfig + apiURL string + client *req.Client + licenseService *service.LicenseService } -func NewPlusClient(config types.MjPlusConfig) *PlusClient { +func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient { return &PlusClient{ - Config: config, - apiURL: config.ApiURL, - client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), + Config: config, + apiURL: config.ApiURL, + client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), + licenseService: licenseService, } } +func (c *PlusClient) preCheck() error { + return c.licenseService.IsValidApiURL(c.Config.ApiURL) +} + func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) if task.NegPrompt != "" { @@ -79,6 +90,10 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { // Blend 融图 func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) logger.Info("API URL: ", apiURL) body := ImageReq{ @@ -118,6 +133,10 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { // SwapFace 换脸 func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) // 生成图片 Base64 编码 if len(task.ImgArr) != 2 { @@ -167,6 +186,10 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { // Upscale 放大指定的图片 func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + body := map[string]string{ "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, @@ -194,6 +217,10 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + body := map[string]string{ "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), "taskId": task.MessageId, diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index b28a3d2d..ddddd280 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -62,13 +62,8 @@ func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfig if config.Enabled == false { continue } - err := p.licenseService.IsValidApiURL(config.ApiURL) - if err != nil { - logger.Errorf("创建 MJ-PLUS 服务失败:%v", err) - continue - } - cli := NewPlusClient(config) + cli := NewPlusClient(config, p.licenseService) name := fmt.Sprintf("mj-plus-service-%d", k) plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) go func() { diff --git a/api/service/mj/service.go b/api/service/mj/service.go index e72d7476..baccd281 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -108,7 +108,13 @@ func (s *Service) Run() { } if err != nil || (res.Code != 1 && res.Code != 22) { - errMsg := fmt.Sprintf("%v,%s", err, res.Description) + var errMsg string + if err != nil { + errMsg = err.Error() + } else { + errMsg = fmt.Sprintf("%v,%s", err, res.Description) + } + logger.Error("绘画任务执行失败:", errMsg) job.Progress = -1 job.ErrMsg = errMsg