mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
refactor midjourney service, use api key in database
This commit is contained in:
@@ -7,15 +7,28 @@ package mj
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import "geekai/core/types"
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
type Client interface {
|
||||
Imagine(task types.MjTask) (ImageRes, error)
|
||||
Blend(task types.MjTask) (ImageRes, error)
|
||||
SwapFace(task types.MjTask) (ImageRes, error)
|
||||
Upscale(task types.MjTask) (ImageRes, error)
|
||||
Variation(task types.MjTask) (ImageRes, error)
|
||||
QueryTask(taskId string) (QueryRes, error)
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Client MidJourney client
|
||||
type Client struct {
|
||||
client *req.Client
|
||||
licenseService *service.LicenseService
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type ImageReq struct {
|
||||
@@ -33,7 +46,8 @@ type ImageRes struct {
|
||||
Description string `json:"description"`
|
||||
Properties struct {
|
||||
} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
Result string `json:"result"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
type ErrRes struct {
|
||||
@@ -66,3 +80,184 @@ type QueryRes struct {
|
||||
Status string `json:"status"`
|
||||
SubmitTime int `json:"submitTime"`
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
|
||||
return &Client{
|
||||
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
|
||||
licenseService: licenseService,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
|
||||
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
|
||||
if task.NegPrompt != "" {
|
||||
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
|
||||
}
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Prompt: prompt,
|
||||
Base64Array: make([]string, 0),
|
||||
}
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) > 0 {
|
||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
} else {
|
||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||
}
|
||||
|
||||
}
|
||||
return c.doRequest(body, apiPath, task.ChannelId)
|
||||
}
|
||||
|
||||
// Blend 融图
|
||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
||||
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Dimensions: "SQUARE",
|
||||
Base64Array: make([]string, 0),
|
||||
}
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) > 0 {
|
||||
for _, imgURL := range task.ImgArr {
|
||||
imageData, err := utils.DownloadImage(imgURL, "")
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
} else {
|
||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.doRequest(body, apiPath, task.ChannelId)
|
||||
}
|
||||
|
||||
// SwapFace 换脸
|
||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) != 2 {
|
||||
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
||||
}
|
||||
var sourceBase64 string
|
||||
var targetBase64 string
|
||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
} else {
|
||||
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
}
|
||||
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
} else {
|
||||
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
}
|
||||
|
||||
body := gin.H{
|
||||
"sourceBase64": sourceBase64,
|
||||
"targetBase64": targetBase64,
|
||||
"accountFilter": gin.H{
|
||||
"instanceId": "",
|
||||
},
|
||||
"state": "",
|
||||
}
|
||||
return c.doRequest(body, apiPath, task.ChannelId)
|
||||
}
|
||||
|
||||
// Upscale 放大指定的图片
|
||||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||
"taskId": task.MessageId,
|
||||
}
|
||||
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
|
||||
return c.doRequest(body, apiPath, task.ChannelId)
|
||||
}
|
||||
|
||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||
"taskId": task.MessageId,
|
||||
}
|
||||
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
|
||||
|
||||
return c.doRequest(body, apiPath, task.ChannelId)
|
||||
}
|
||||
|
||||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
|
||||
if channel != "" {
|
||||
session = session.Where("api_url", channel)
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
err := session.Order("last_used_at ASC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
|
||||
}
|
||||
|
||||
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
|
||||
return ImageRes{}, err
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
|
||||
logger.Info("API URL: ", apiURL)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if r != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
logger.Error("请求 API 出错:", string(errStr))
|
||||
errMsg = errMsg + " " + string(errStr)
|
||||
}
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
// update the api key last used time
|
||||
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
|
||||
logger.Error("update api key last used time error: ", err)
|
||||
}
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
|
||||
var res QueryRes
|
||||
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return QueryRes{}, err
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user