// Mirror of https://github.com/yangjian102621/geekai.git (synced 2025-11-04 16:23:42 +08:00).
// 251 lines, 7.5 KiB, Go.

package mj
 | 
						||
 | 
						||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by an Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
						||
 | 
						||
import (
 | 
						||
	"encoding/base64"
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"geekai/core/types"
 | 
						||
	logger2 "geekai/logger"
 | 
						||
	"geekai/service"
 | 
						||
	"geekai/store/model"
 | 
						||
	"geekai/utils"
 | 
						||
	"github.com/imroc/req/v3"
 | 
						||
	"gorm.io/gorm"
 | 
						||
	"io"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gin-gonic/gin"
 | 
						||
)
 | 
						||
 | 
						||
// Client MidJourney client
 | 
						||
type Client struct {
 | 
						||
	client         *req.Client
 | 
						||
	licenseService *service.LicenseService
 | 
						||
	db             *gorm.DB
 | 
						||
}
 | 
						||
 | 
						||
// ImageReq is the JSON request body sent to the image submit endpoints
// (used by Imagine and Blend). Every field is omitted from the encoded JSON
// when it holds its zero value.
type ImageReq struct {
	// BotType selects the bot; this package always sends "MID_JOURNEY".
	BotType string `json:"botType,omitempty"`
	// Prompt is the full prompt text, including parameters.
	Prompt string `json:"prompt,omitempty"`
	// Dimensions is the output layout; Blend sends "SQUARE".
	Dimensions string `json:"dimensions,omitempty"`
	// Base64Array holds the input images as base64 data URIs.
	Base64Array []string `json:"base64Array,omitempty"`
	// AccountFilter is passed through as-is; it is not set anywhere in this
	// file — NOTE(review): confirm the expected shape against the API.
	AccountFilter interface{} `json:"accountFilter,omitempty"`
	// NotifyHook is presumably a callback URL — TODO confirm; unused here.
	NotifyHook string `json:"notifyHook,omitempty"`
	// State is presumably an opaque caller-defined value — TODO confirm; unused here.
	State string `json:"state,omitempty"`
}
 | 
						||
 | 
						||
// ImageRes is the JSON response decoded from a submit request (see doRequest).
type ImageRes struct {
	// Code is the status code reported by the API in the response body.
	Code int `json:"code"`
	// Description is the human-readable message that accompanies Code.
	Description string `json:"description"`
	// Properties is declared without fields, so whatever the API returns under
	// "properties" is discarded when decoding.
	Properties struct {
	} `json:"properties"`
	// Result is the "result" string returned by the API — presumably the
	// submitted task id; TODO confirm against the API.
	Result string `json:"result"`
	// Channel is not part of the API response; doRequest fills it with the API
	// URL of the key that served the request.
	Channel string `json:"channel,omitempty"`
}
 | 
						||
 | 
						||
// QueryRes is the JSON document decoded by QueryTask from the task fetch
// endpoint.
type QueryRes struct {
	Action string `json:"action"`
	// Buttons lists the follow-up actions returned with the task.
	Buttons []struct {
		CustomId string `json:"customId"`
		Emoji    string `json:"emoji"`
		Label    string `json:"label"`
		Style    int    `json:"style"`
		Type     int    `json:"type"`
	} `json:"buttons"`
	Description string `json:"description"`
	FailReason  string `json:"failReason"`
	// FinishTime, StartTime and SubmitTime are decoded as plain integers —
	// presumably epoch timestamps; TODO confirm the unit against the API.
	FinishTime int    `json:"finishTime"`
	Id         string `json:"id"`
	ImageUrl   string `json:"imageUrl"`
	Progress   string `json:"progress"`
	Prompt     string `json:"prompt"`
	PromptEn   string `json:"promptEn"`
	// Properties is declared without fields, so whatever the API returns under
	// "properties" is discarded when decoding.
	Properties struct {
	} `json:"properties"`
	StartTime  int    `json:"startTime"`
	State      string `json:"state"`
	Status     string `json:"status"`
	SubmitTime int    `json:"submitTime"`
}
 | 
						||
 | 
						||
// logger is the package-level logger obtained from the geekai/logger package.
var logger = logger2.GetLogger()
 | 
						||
 | 
						||
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
 | 
						||
	return &Client{
 | 
						||
		client:         req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
 | 
						||
		licenseService: licenseService,
 | 
						||
		db:             db,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
						||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
 | 
						||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
						||
	if task.NegPrompt != "" {
 | 
						||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
						||
	}
 | 
						||
	body := ImageReq{
 | 
						||
		BotType:     "MID_JOURNEY",
 | 
						||
		Prompt:      prompt,
 | 
						||
		Base64Array: make([]string, 0),
 | 
						||
	}
 | 
						||
	// 生成图片 Base64 编码
 | 
						||
	if len(task.ImgArr) > 0 {
 | 
						||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
						||
		if err != nil {
 | 
						||
			logger.Error("error with download image: ", err)
 | 
						||
		} else {
 | 
						||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
						||
		}
 | 
						||
 | 
						||
	}
 | 
						||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
						||
}
 | 
						||
 | 
						||
// Blend 融图
 | 
						||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
 | 
						||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
 | 
						||
	body := ImageReq{
 | 
						||
		BotType:     "MID_JOURNEY",
 | 
						||
		Dimensions:  "SQUARE",
 | 
						||
		Base64Array: make([]string, 0),
 | 
						||
	}
 | 
						||
	// 生成图片 Base64 编码
 | 
						||
	if len(task.ImgArr) > 0 {
 | 
						||
		for _, imgURL := range task.ImgArr {
 | 
						||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
						||
			if err != nil {
 | 
						||
				logger.Error("error with download image: ", err)
 | 
						||
			} else {
 | 
						||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
						||
			}
 | 
						||
		}
 | 
						||
	}
 | 
						||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
						||
}
 | 
						||
 | 
						||
// SwapFace 换脸
 | 
						||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
						||
	apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
 | 
						||
	// 生成图片 Base64 编码
 | 
						||
	if len(task.ImgArr) != 2 {
 | 
						||
		return ImageRes{}, errors.New("参数错误,必须上传2张图片")
 | 
						||
	}
 | 
						||
	var sourceBase64 string
 | 
						||
	var targetBase64 string
 | 
						||
	imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
						||
	if err != nil {
 | 
						||
		logger.Error("error with download image: ", err)
 | 
						||
	} else {
 | 
						||
		sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
						||
	}
 | 
						||
	imageData, err = utils.DownloadImage(task.ImgArr[1], "")
 | 
						||
	if err != nil {
 | 
						||
		logger.Error("error with download image: ", err)
 | 
						||
	} else {
 | 
						||
		targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
						||
	}
 | 
						||
 | 
						||
	body := gin.H{
 | 
						||
		"sourceBase64": sourceBase64,
 | 
						||
		"targetBase64": targetBase64,
 | 
						||
		"accountFilter": gin.H{
 | 
						||
			"instanceId": "",
 | 
						||
		},
 | 
						||
		"state": "",
 | 
						||
	}
 | 
						||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
						||
}
 | 
						||
 | 
						||
// Upscale 放大指定的图片
 | 
						||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
 | 
						||
	body := map[string]string{
 | 
						||
		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
						||
		"taskId":   task.MessageId,
 | 
						||
	}
 | 
						||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
 | 
						||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
						||
}
 | 
						||
 | 
						||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
						||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
 | 
						||
	body := map[string]string{
 | 
						||
		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
						||
		"taskId":   task.MessageId,
 | 
						||
	}
 | 
						||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
 | 
						||
 | 
						||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
 | 
						||
	var res ImageRes
 | 
						||
	session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
 | 
						||
	if channel != "" {
 | 
						||
		session = session.Where("api_url", channel)
 | 
						||
	}
 | 
						||
 | 
						||
	var apiKey model.ApiKey
 | 
						||
	err := session.Order("last_used_at ASC").First(&apiKey).Error
 | 
						||
	if err != nil {
 | 
						||
		return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
 | 
						||
	}
 | 
						||
 | 
						||
	if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
 | 
						||
		return ImageRes{}, err
 | 
						||
	}
 | 
						||
 | 
						||
	apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
 | 
						||
	logger.Info("API URL: ", apiURL)
 | 
						||
	r, err := req.C().R().
 | 
						||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
						||
		SetBody(body).
 | 
						||
		SetSuccessResult(&res).
 | 
						||
		Post(apiURL)
 | 
						||
	if err != nil {
 | 
						||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
						||
	}
 | 
						||
 | 
						||
	if r.IsErrorState() {
 | 
						||
		errMsg, _ := io.ReadAll(r.Body)
 | 
						||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
 | 
						||
	}
 | 
						||
 | 
						||
	// update the api key last used time
 | 
						||
	if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
 | 
						||
		logger.Error("update api key last used time error: ", err)
 | 
						||
	}
 | 
						||
	res.Channel = apiKey.ApiURL
 | 
						||
	return res, nil
 | 
						||
}
 | 
						||
 | 
						||
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
 | 
						||
	var apiKey model.ApiKey
 | 
						||
	err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
 | 
						||
	if err != nil {
 | 
						||
		return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
 | 
						||
	}
 | 
						||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
 | 
						||
	var res QueryRes
 | 
						||
	r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
						||
		SetSuccessResult(&res).
 | 
						||
		Get(apiURL)
 | 
						||
 | 
						||
	if err != nil {
 | 
						||
		return QueryRes{}, err
 | 
						||
	}
 | 
						||
 | 
						||
	if r.IsErrorState() {
 | 
						||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
						||
	}
 | 
						||
 | 
						||
	return res, nil
 | 
						||
}
 |