mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy
This commit is contained in:
		@@ -1,4 +1,13 @@
 | 
			
		||||
# 更新日志
 | 
			
		||||
## v4.0.1
 | 
			
		||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些
 | 
			
		||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转
 | 
			
		||||
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
 | 
			
		||||
* 功能新增:支持前端菜单可以配置
 | 
			
		||||
* 功能优化:手机端支持免登录预览功能
 | 
			
		||||
* 功能新增:手机端支持 Stable-Diffusion 绘画
 | 
			
		||||
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
 | 
			
		||||
 | 
			
		||||
## v4.0.0
 | 
			
		||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
 | 
			
		||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
 | 
			
		||||
 
 | 
			
		||||
@@ -5,22 +5,22 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppConfig struct {
 | 
			
		||||
	Path          string `toml:"-"`
 | 
			
		||||
	Listen        string
 | 
			
		||||
	Session       Session
 | 
			
		||||
	AdminSession  Session
 | 
			
		||||
	ProxyURL      string
 | 
			
		||||
	MysqlDns      string                  // mysql 连接地址
 | 
			
		||||
	StaticDir     string                  // 静态资源目录
 | 
			
		||||
	StaticUrl     string                  // 静态资源 URL
 | 
			
		||||
	Redis         RedisConfig             // redis 连接信息
 | 
			
		||||
	ApiConfig     ChatPlusApiConfig       // ChatPlus API authorization configs
 | 
			
		||||
	SMS           SMSConfig               // send mobile message config
 | 
			
		||||
	OSS           OSSConfig               // OSS config
 | 
			
		||||
	MjConfigs     []MidJourneyConfig      // mj AI draw service pool
 | 
			
		||||
	MjPlusConfigs []MidJourneyPlusConfig  // MJ plus config
 | 
			
		||||
	WeChatBot     bool                    // 是否启用微信机器人
 | 
			
		||||
	SdConfigs     []StableDiffusionConfig // sd AI draw service pool
 | 
			
		||||
	Path           string `toml:"-"`
 | 
			
		||||
	Listen         string
 | 
			
		||||
	Session        Session
 | 
			
		||||
	AdminSession   Session
 | 
			
		||||
	ProxyURL       string
 | 
			
		||||
	MysqlDns       string                  // mysql 连接地址
 | 
			
		||||
	StaticDir      string                  // 静态资源目录
 | 
			
		||||
	StaticUrl      string                  // 静态资源 URL
 | 
			
		||||
	Redis          RedisConfig             // redis 连接信息
 | 
			
		||||
	ApiConfig      ChatPlusApiConfig       // ChatPlus API authorization configs
 | 
			
		||||
	SMS            SMSConfig               // send mobile message config
 | 
			
		||||
	OSS            OSSConfig               // OSS config
 | 
			
		||||
	MjProxyConfigs []MjProxyConfig         // MJ proxy config
 | 
			
		||||
	MjPlusConfigs  []MjPlusConfig          // MJ plus config
 | 
			
		||||
	WeChatBot      bool                    // 是否启用微信机器人
 | 
			
		||||
	SdConfigs      []StableDiffusionConfig // sd AI draw service pool
 | 
			
		||||
 | 
			
		||||
	XXLConfig     XXLConfig
 | 
			
		||||
	AlipayConfig  AlipayConfig
 | 
			
		||||
@@ -43,16 +43,11 @@ type ChatPlusApiConfig struct {
 | 
			
		||||
	Token  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidJourneyConfig struct {
 | 
			
		||||
	Enabled        bool
 | 
			
		||||
	UserToken      string
 | 
			
		||||
	BotToken       string
 | 
			
		||||
	GuildId        string // Server ID
 | 
			
		||||
	ChanelId       string // Chanel ID
 | 
			
		||||
	UseCDN         bool
 | 
			
		||||
	ImgCdnURL      string // 图片反代加速地址
 | 
			
		||||
	DiscordAPI     string
 | 
			
		||||
	DiscordGateway string
 | 
			
		||||
type MjProxyConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StableDiffusionConfig struct {
 | 
			
		||||
@@ -62,12 +57,11 @@ type StableDiffusionConfig struct {
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidJourneyPlusConfig struct {
 | 
			
		||||
	Enabled   bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
			
		||||
	ApiURL    string // api 地址
 | 
			
		||||
	Mode      string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey    string
 | 
			
		||||
	NotifyURL string // 任务进度更新回调地址
 | 
			
		||||
type MjPlusConfig struct {
 | 
			
		||||
	Enabled bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AlipayConfig struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/service/mj/plus"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
@@ -454,27 +453,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Notify MidJourney Plus 服务任务回调处理
 | 
			
		||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
			
		||||
	var data plus.CBReq
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		logger.Error("非法任务回调:%+v", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err := h.pool.Notify(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := h.GetLoginUserId(c)
 | 
			
		||||
		client := h.pool.Clients.Get(userId)
 | 
			
		||||
		if client != nil {
 | 
			
		||||
			_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish 发布图片到画廊显示
 | 
			
		||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -252,7 +252,6 @@ func main() {
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("notify", h.Notify)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,233 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	discordgo "github.com/bg5t/mydiscordgo"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MidJourney 机器人
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Bot struct {
 | 
			
		||||
	config  types.MidJourneyConfig
 | 
			
		||||
	bot     *discordgo.Session
 | 
			
		||||
	name    string
 | 
			
		||||
	service *Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
 | 
			
		||||
	bot, err := discordgo.New("Bot " + config.BotToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// use CDN reverse proxy
 | 
			
		||||
	if config.UseCDN {
 | 
			
		||||
		discordgo.SetEndpointDiscord(config.DiscordAPI)
 | 
			
		||||
		discordgo.SetEndpointCDN("https://cdn.discordapp.com")
 | 
			
		||||
		discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
 | 
			
		||||
		bot.MjGateway = config.DiscordGateway + "/"
 | 
			
		||||
	} else { // use proxy
 | 
			
		||||
		discordgo.SetEndpointDiscord("https://discord.com")
 | 
			
		||||
		discordgo.SetEndpointCDN("https://cdn.discordapp.com")
 | 
			
		||||
		discordgo.SetEndpointStatus("https://discord.com/api/v2/")
 | 
			
		||||
		bot.MjGateway = "wss://gateway.discord.gg"
 | 
			
		||||
 | 
			
		||||
		if proxy != "" {
 | 
			
		||||
			proxy, _ := url.Parse(proxy)
 | 
			
		||||
			bot.Client = &http.Client{
 | 
			
		||||
				Transport: &http.Transport{
 | 
			
		||||
					Proxy: http.ProxyURL(proxy),
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			bot.Dialer = &websocket.Dialer{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		config:  config,
 | 
			
		||||
		bot:     bot,
 | 
			
		||||
		name:    name,
 | 
			
		||||
		service: service,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) Run() error {
 | 
			
		||||
	b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
 | 
			
		||||
	b.bot.AddHandler(b.messageCreate)
 | 
			
		||||
	b.bot.AddHandler(b.messageUpdate)
 | 
			
		||||
 | 
			
		||||
	logger.Infof("Starting MidJourney %s", b.name)
 | 
			
		||||
	err := b.bot.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof("Starting MidJourney %s successfully!", b.name)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskStatus string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Start    = TaskStatus("Started")
 | 
			
		||||
	Running  = TaskStatus("Running")
 | 
			
		||||
	Stopped  = TaskStatus("Stopped")
 | 
			
		||||
	Finished = TaskStatus("Finished")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Image struct {
 | 
			
		||||
	URL      string `json:"url"`
 | 
			
		||||
	ProxyURL string `json:"proxy_url"`
 | 
			
		||||
	Filename string `json:"filename"`
 | 
			
		||||
	Width    int    `json:"width"`
 | 
			
		||||
	Height   int    `json:"height"`
 | 
			
		||||
	Size     int    `json:"size"`
 | 
			
		||||
	Hash     string `json:"hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author == nil || m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("CREATE: %s", utils.JsonEncode(m))
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
			
		||||
		// parse content
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   m.ChannelID,
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    0,
 | 
			
		||||
			Status:      Start}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author == nil || m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
 | 
			
		||||
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Stopped)") {
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   m.ChannelID,
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    extractProgress(m.Content),
 | 
			
		||||
			Status:      Stopped}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
			
		||||
	progress := extractProgress(content)
 | 
			
		||||
	var status TaskStatus
 | 
			
		||||
	if progress == 100 {
 | 
			
		||||
		status = Finished
 | 
			
		||||
	} else {
 | 
			
		||||
		status = Running
 | 
			
		||||
	}
 | 
			
		||||
	for _, attachment := range attachments {
 | 
			
		||||
		if attachment.Width == 0 || attachment.Height == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		image := Image{
 | 
			
		||||
			URL:      attachment.URL,
 | 
			
		||||
			Height:   attachment.Height,
 | 
			
		||||
			ProxyURL: attachment.ProxyURL,
 | 
			
		||||
			Width:    attachment.Width,
 | 
			
		||||
			Size:     attachment.Size,
 | 
			
		||||
			Filename: attachment.Filename,
 | 
			
		||||
			Hash:     extractHashFromFilename(attachment.Filename),
 | 
			
		||||
		}
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			ChannelId:   channelId,
 | 
			
		||||
			MessageId:   messageId,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Image:       image,
 | 
			
		||||
			Prompt:      extractPrompt(content),
 | 
			
		||||
			Content:     content,
 | 
			
		||||
			Progress:    progress,
 | 
			
		||||
			Status:      status,
 | 
			
		||||
		}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		break // only get one image
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// extract prompt from string
 | 
			
		||||
func extractPrompt(input string) string {
 | 
			
		||||
	pattern := `\*\*(.*?)\*\*`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return strings.TrimSpace(matches[1])
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractProgress(input string) int {
 | 
			
		||||
	pattern := `\((\d+)\%\)`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return utils.IntValue(matches[1], 0)
 | 
			
		||||
	}
 | 
			
		||||
	return 100
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractHashFromFilename(filename string) string {
 | 
			
		||||
	if !strings.HasSuffix(filename, ".png") {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	index := strings.LastIndex(filename, "_")
 | 
			
		||||
	if index != -1 {
 | 
			
		||||
		return filename[index+1 : len(filename)-4]
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
@@ -1,159 +1,61 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
import "chatplus/core/types"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MidJourney client
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
	client *req.Client
 | 
			
		||||
	Config types.MidJourneyConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
type Client interface {
 | 
			
		||||
	Imagine(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Blend(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	SwapFace(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Upscale(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Variation(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	QueryTask(taskId string) (QueryRes, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
 | 
			
		||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	// set proxy URL
 | 
			
		||||
	if config.UseCDN {
 | 
			
		||||
		apiURL = config.DiscordAPI + "/api/v9/interactions"
 | 
			
		||||
	} else {
 | 
			
		||||
		apiURL = "https://discord.com/api/v9/interactions"
 | 
			
		||||
		if proxy != "" {
 | 
			
		||||
			client.SetProxyURL(proxy)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Client{client: client, Config: config, apiURL: apiURL}
 | 
			
		||||
type ImageReq struct {
 | 
			
		||||
	BotType       string      `json:"botType,omitempty"`
 | 
			
		||||
	Prompt        string      `json:"prompt,omitempty"`
 | 
			
		||||
	Dimensions    string      `json:"dimensions,omitempty"`
 | 
			
		||||
	Base64Array   []string    `json:"base64Array,omitempty"`
 | 
			
		||||
	AccountFilter interface{} `json:"accountFilter,omitempty"`
 | 
			
		||||
	NotifyHook    string      `json:"notifyHook,omitempty"`
 | 
			
		||||
	State         string      `json:"state,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(task types.MjTask) error {
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          2,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.Config.GuildId,
 | 
			
		||||
		ChannelID:     c.Config.ChanelId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"version": "1166847114203123795",
 | 
			
		||||
			"id":      "938956540159881230",
 | 
			
		||||
			"name":    "imagine",
 | 
			
		||||
			"type":    "1",
 | 
			
		||||
			"options": []map[string]any{
 | 
			
		||||
				{
 | 
			
		||||
					"type":  3,
 | 
			
		||||
					"name":  "prompt",
 | 
			
		||||
					"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"application_command": map[string]any{
 | 
			
		||||
				"id":                         "938956540159881230",
 | 
			
		||||
				"application_id":             ApplicationID,
 | 
			
		||||
				"version":                    "1118961510123847772",
 | 
			
		||||
				"default_permission":         true,
 | 
			
		||||
				"default_member_permissions": nil,
 | 
			
		||||
				"type":                       1,
 | 
			
		||||
				"nsfw":                       false,
 | 
			
		||||
				"name":                       "imagine",
 | 
			
		||||
				"description":                "Create images with Midjourney",
 | 
			
		||||
				"dm_permission":              true,
 | 
			
		||||
				"options": []map[string]any{
 | 
			
		||||
					{
 | 
			
		||||
						"type":        3,
 | 
			
		||||
						"name":        "prompt",
 | 
			
		||||
						"description": "The prompt to imagine",
 | 
			
		||||
						"required":    true,
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
				"attachments": []any{},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		Post(c.apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %w%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
type ImageRes struct {
 | 
			
		||||
	Code        int    `json:"code"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	Result string `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Blend(task types.MjTask) error {
 | 
			
		||||
	return errors.New("function not implemented")
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) SwapFace(task types.MjTask) error {
 | 
			
		||||
	return errors.New("function not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *Client) Upscale(task types.MjTask) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.Config.GuildId,
 | 
			
		||||
		ChannelID:     c.Config.ChanelId,
 | 
			
		||||
		MessageFlags:  flags,
 | 
			
		||||
		MessageID:     task.MessageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(c.apiURL)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *Client) Variation(task types.MjTask) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.Config.GuildId,
 | 
			
		||||
		ChannelID:     c.Config.ChanelId,
 | 
			
		||||
		MessageFlags:  flags,
 | 
			
		||||
		MessageID:     task.MessageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(c.apiURL)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
type QueryRes struct {
 | 
			
		||||
	Action  string `json:"action"`
 | 
			
		||||
	Buttons []struct {
 | 
			
		||||
		CustomId string `json:"customId"`
 | 
			
		||||
		Emoji    string `json:"emoji"`
 | 
			
		||||
		Label    string `json:"label"`
 | 
			
		||||
		Style    int    `json:"style"`
 | 
			
		||||
		Type     int    `json:"type"`
 | 
			
		||||
	} `json:"buttons"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	FailReason  string `json:"failReason"`
 | 
			
		||||
	FinishTime  int    `json:"finishTime"`
 | 
			
		||||
	Id          string `json:"id"`
 | 
			
		||||
	ImageUrl    string `json:"imageUrl"`
 | 
			
		||||
	Progress    string `json:"progress"`
 | 
			
		||||
	Prompt      string `json:"prompt"`
 | 
			
		||||
	PromptEn    string `json:"promptEn"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	StartTime  int    `json:"startTime"`
 | 
			
		||||
	State      string `json:"state"`
 | 
			
		||||
	Status     string `json:"status"`
 | 
			
		||||
	SubmitTime int    `json:"submitTime"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,194 +0,0 @@
 | 
			
		||||
package plus
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	Name             string  // service Name
 | 
			
		||||
	Client           *Client // MJ Client
 | 
			
		||||
	taskQueue        *store.RedisQueue
 | 
			
		||||
	notifyQueue      *store.RedisQueue
 | 
			
		||||
	db               *gorm.DB
 | 
			
		||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
			
		||||
	HandledTaskNum   int32             // already handled task number
 | 
			
		||||
	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
			
		||||
	taskTimeout      int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		Name:             name,
 | 
			
		||||
		db:               db,
 | 
			
		||||
		taskQueue:        taskQueue,
 | 
			
		||||
		notifyQueue:      notifyQueue,
 | 
			
		||||
		Client:           client,
 | 
			
		||||
		taskTimeout:      timeout,
 | 
			
		||||
		maxHandleTaskNum: maxTaskNum,
 | 
			
		||||
		taskStartTimes:   make(map[int]time.Time, 0),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.Name)
 | 
			
		||||
	for {
 | 
			
		||||
		s.checkTasks()
 | 
			
		||||
		if !s.canHandleTask() {
 | 
			
		||||
			// current service is full, can not handle more task
 | 
			
		||||
			// waiting for running task finish
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// if it's reference message, check if it's this channel's  message
 | 
			
		||||
		//if task.ChannelId != "" && task.ChannelId != s.Name {
 | 
			
		||||
		//	logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
 | 
			
		||||
		//	s.taskQueue.RPush(task)
 | 
			
		||||
		//	time.Sleep(time.Second)
 | 
			
		||||
		//	continue
 | 
			
		||||
		//}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
 | 
			
		||||
		var res ImageRes
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			res, err = s.Client.Imagine(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			res, err = s.Client.Upscale(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			res, err = s.Client.Variation(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskBlend:
 | 
			
		||||
			res, err = s.Client.Blend(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskSwapFace:
 | 
			
		||||
			res, err = s.Client.SwapFace(task)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
			
		||||
			errMsg := fmt.Sprintf("%v,%s", err, res.Description)
 | 
			
		||||
			logger.Error("绘画任务执行失败:", errMsg)
 | 
			
		||||
			job.Progress = -1
 | 
			
		||||
			job.ErrMsg = errMsg
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Updates(&job)
 | 
			
		||||
			// 任务失败,通知前端
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
		// lock the task until the execute timeout
 | 
			
		||||
		s.taskStartTimes[int(task.Id)] = time.Now()
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, 1)
 | 
			
		||||
		// 更新任务 ID/频道
 | 
			
		||||
		job.TaskId = res.Result
 | 
			
		||||
		job.ChannelId = s.Name
 | 
			
		||||
		s.db.Updates(&job)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check if current service instance can handle more task
 | 
			
		||||
func (s *Service) canHandleTask() bool {
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.HandledTaskNum)
 | 
			
		||||
	return handledNum < s.maxHandleTaskNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// remove the expired tasks
 | 
			
		||||
func (s *Service) checkTasks() {
 | 
			
		||||
	for k, t := range s.taskStartTimes {
 | 
			
		||||
		if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
			
		||||
			delete(s.taskStartTimes, k)
 | 
			
		||||
			atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
			// delete task from database
 | 
			
		||||
			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	Id          string      `json:"id"`
 | 
			
		||||
	Action      string      `json:"action"`
 | 
			
		||||
	Status      string      `json:"status"`
 | 
			
		||||
	Prompt      string      `json:"prompt"`
 | 
			
		||||
	PromptEn    string      `json:"promptEn"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	SubmitTime  int64       `json:"submitTime"`
 | 
			
		||||
	StartTime   int64       `json:"startTime"`
 | 
			
		||||
	FinishTime  int64       `json:"finishTime"`
 | 
			
		||||
	Progress    string      `json:"progress"`
 | 
			
		||||
	ImageUrl    string      `json:"imageUrl"`
 | 
			
		||||
	FailReason  interface{} `json:"failReason"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
		FinalPrompt string `json:"finalPrompt"`
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	task, err := s.Client.QueryTask(job.TaskId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 任务执行失败了
 | 
			
		||||
	if task.FailReason != "" {
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  task.FailReason,
 | 
			
		||||
		})
 | 
			
		||||
		return fmt.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(task.Buttons) > 0 {
 | 
			
		||||
		job.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
	}
 | 
			
		||||
	oldProgress := job.Progress
 | 
			
		||||
	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
	job.Prompt = task.PromptEn
 | 
			
		||||
	if task.ImageUrl != "" {
 | 
			
		||||
		job.OrgURL = task.ImageUrl
 | 
			
		||||
	}
 | 
			
		||||
	job.MessageId = task.Id
 | 
			
		||||
	tx := s.db.Updates(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with update database: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
	if task.Status == "SUCCESS" {
 | 
			
		||||
		// release lock task
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
	}
 | 
			
		||||
	// 通知前端更新任务进度
 | 
			
		||||
	if oldProgress != job.Progress {
 | 
			
		||||
		s.notifyQueue.RPush(job.UserId)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetImageHash(action string) string {
 | 
			
		||||
	split := strings.Split(action, "::")
 | 
			
		||||
	if len(split) > 5 {
 | 
			
		||||
		return split[4]
 | 
			
		||||
	}
 | 
			
		||||
	return split[len(split)-1]
 | 
			
		||||
}
 | 
			
		||||
@@ -1,8 +1,7 @@
 | 
			
		||||
package plus
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
@@ -13,53 +12,21 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// Client MidJourney Plus Client
 | 
			
		||||
type Client struct {
 | 
			
		||||
	Config types.MidJourneyPlusConfig
 | 
			
		||||
// PlusClient MidJourney Plus ProxyClient
 | 
			
		||||
type PlusClient struct {
 | 
			
		||||
	Config types.MjPlusConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
 | 
			
		||||
	return &Client{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
 | 
			
		||||
	return &PlusClient{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageReq struct {
 | 
			
		||||
	BotType       string   `json:"botType"`
 | 
			
		||||
	Prompt        string   `json:"prompt,omitempty"`
 | 
			
		||||
	Dimensions    string   `json:"dimensions,omitempty"`
 | 
			
		||||
	Base64Array   []string `json:"base64Array,omitempty"`
 | 
			
		||||
	AccountFilter struct {
 | 
			
		||||
		InstanceId          string        `json:"instanceId"`
 | 
			
		||||
		Modes               []interface{} `json:"modes"`
 | 
			
		||||
		Remix               bool          `json:"remix"`
 | 
			
		||||
		RemixAutoConsidered bool          `json:"remixAutoConsidered"`
 | 
			
		||||
	} `json:"accountFilter,omitempty"`
 | 
			
		||||
	NotifyHook string `json:"notifyHook"`
 | 
			
		||||
	State      string `json:"state,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageRes struct {
 | 
			
		||||
	Code        int    `json:"code"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	Result string `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Prompt:      task.Prompt,
 | 
			
		||||
		NotifyHook:  c.Config.NotifyURL,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
@@ -94,12 +61,11 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		NotifyHook:  c.Config.NotifyURL,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
@@ -133,7 +99,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) != 2 {
 | 
			
		||||
@@ -160,8 +126,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
		"accountFilter": gin.H{
 | 
			
		||||
			"instanceId": "",
 | 
			
		||||
		},
 | 
			
		||||
		"notifyHook": c.Config.NotifyURL,
 | 
			
		||||
		"state":      "",
 | 
			
		||||
		"state": "",
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
@@ -183,11 +148,10 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId":   fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":     task.MessageId,
 | 
			
		||||
		"notifyHook": c.Config.NotifyURL,
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
@@ -210,11 +174,10 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId":   fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":     task.MessageId,
 | 
			
		||||
		"notifyHook": c.Config.NotifyURL,
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
@@ -236,32 +199,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueryRes struct {
 | 
			
		||||
	Action  string `json:"action"`
 | 
			
		||||
	Buttons []struct {
 | 
			
		||||
		CustomId string `json:"customId"`
 | 
			
		||||
		Emoji    string `json:"emoji"`
 | 
			
		||||
		Label    string `json:"label"`
 | 
			
		||||
		Style    int    `json:"style"`
 | 
			
		||||
		Type     int    `json:"type"`
 | 
			
		||||
	} `json:"buttons"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	FailReason  string `json:"failReason"`
 | 
			
		||||
	FinishTime  int    `json:"finishTime"`
 | 
			
		||||
	Id          string `json:"id"`
 | 
			
		||||
	ImageUrl    string `json:"imageUrl"`
 | 
			
		||||
	Progress    string `json:"progress"`
 | 
			
		||||
	Prompt      string `json:"prompt"`
 | 
			
		||||
	PromptEn    string `json:"promptEn"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	StartTime  int    `json:"startTime"`
 | 
			
		||||
	State      string `json:"state"`
 | 
			
		||||
	Status     string `json:"status"`
 | 
			
		||||
	SubmitTime int    `json:"submitTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
@@ -278,3 +216,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &PlusClient{}
 | 
			
		||||
@@ -2,13 +2,12 @@ package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj/plus"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -16,7 +15,7 @@ import (
 | 
			
		||||
 | 
			
		||||
// ServicePool Mj service pool
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services        []interface{}
 | 
			
		||||
	services        []*Service
 | 
			
		||||
	taskQueue       *store.RedisQueue
 | 
			
		||||
	notifyQueue     *store.RedisQueue
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
@@ -24,8 +23,10 @@ type ServicePool struct {
 | 
			
		||||
	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
			
		||||
	services := make([]interface{}, 0)
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
			
		||||
 | 
			
		||||
@@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		client := plus.NewClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-service-plus-%d", k)
 | 
			
		||||
		servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
 | 
			
		||||
		cli := NewPlusClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-plus-service-%d", k)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			servicePlus.Run()
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		services = append(services, servicePlus)
 | 
			
		||||
		services = append(services, service)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(services) == 0 {
 | 
			
		||||
		// create mj client and service
 | 
			
		||||
		for k, config := range appConfig.MjConfigs {
 | 
			
		||||
			if config.Enabled == false {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			// create mj client
 | 
			
		||||
			client := NewClient(config, appConfig.ProxyURL)
 | 
			
		||||
 | 
			
		||||
			name := fmt.Sprintf("MjService-%d", k)
 | 
			
		||||
			// create mj service
 | 
			
		||||
			service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
 | 
			
		||||
			botName := fmt.Sprintf("MjBot-%d", k)
 | 
			
		||||
			bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = bot.Run()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// run mj service
 | 
			
		||||
			go func() {
 | 
			
		||||
				service.Run()
 | 
			
		||||
			}()
 | 
			
		||||
 | 
			
		||||
			services = append(services, service)
 | 
			
		||||
	for k, config := range appConfig.MjProxyConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		cli := NewProxyClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-proxy-service-%d", k)
 | 
			
		||||
		service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		services = append(services, service)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
@@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := p.Clients.Get(userId)
 | 
			
		||||
			if client == nil {
 | 
			
		||||
			cli := p.Clients.Get(userId)
 | 
			
		||||
			if cli == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte("Task Updated"))
 | 
			
		||||
			err = cli.Send([]byte("Task Updated"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
@@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
 | 
			
		||||
				logger.Infof("try to download image: %s", v.OrgURL)
 | 
			
		||||
				var imgURL string
 | 
			
		||||
				var err error
 | 
			
		||||
				if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
 | 
			
		||||
				if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
 | 
			
		||||
					task, _ := servicePlus.Client.QueryTask(v.TaskId)
 | 
			
		||||
					if len(task.Buttons) > 0 {
 | 
			
		||||
						v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
						v.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
					}
 | 
			
		||||
					imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
 | 
			
		||||
				} else {
 | 
			
		||||
@@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
 | 
			
		||||
				v.ImgURL = imgURL
 | 
			
		||||
				p.db.Updates(&v)
 | 
			
		||||
 | 
			
		||||
				client := p.Clients.Get(uint(v.UserId))
 | 
			
		||||
				if client == nil {
 | 
			
		||||
				cli := p.Clients.Get(uint(v.UserId))
 | 
			
		||||
				if cli == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				err = client.Send([]byte("Task Updated"))
 | 
			
		||||
				err = cli.Send([]byte("Task Updated"))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
@@ -167,25 +149,6 @@ func (p *ServicePool) HasAvailableService() bool {
 | 
			
		||||
	return len(p.services) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) Notify(data plus.CBReq) error {
 | 
			
		||||
	logger.Debugf("收到任务回调:%+v", data)
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := p.db.Where("task_id = ?", data.Id).First(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return fmt.Errorf("非法任务:%s", data.Id)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 任务已经拉取完成
 | 
			
		||||
	if job.Progress == 100 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
 | 
			
		||||
		return servicePlus.Notify(job)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
@@ -222,11 +185,7 @@ func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
 | 
			
		||||
				if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
 | 
			
		||||
					_ = servicePlus.Notify(job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
@@ -236,12 +195,10 @@ func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
 | 
			
		||||
func (p *ServicePool) getService(name string) *Service {
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		if servicePlus, ok := s.(*plus.Service); ok {
 | 
			
		||||
			if servicePlus.Name == name {
 | 
			
		||||
				return servicePlus
 | 
			
		||||
			}
 | 
			
		||||
		if s.Name == name {
 | 
			
		||||
			return s
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										176
									
								
								api/service/mj/proxy_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								api/service/mj/proxy_client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,176 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ProxyClient MidJourney Proxy Client
 | 
			
		||||
type ProxyClient struct {
 | 
			
		||||
	Config types.MjProxyConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
 | 
			
		||||
	return &ProxyClient{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Prompt:      task.Prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		all, err := io.ReadAll(r.Body)
 | 
			
		||||
		logger.Info(string(all))
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
 | 
			
		||||
	return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "UPSCALE",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "VARIATION",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &ProxyClient{}
 | 
			
		||||
@@ -16,24 +16,24 @@ import (
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	name             string  // service name
 | 
			
		||||
	client           *Client // MJ client
 | 
			
		||||
	Name             string // service Name
 | 
			
		||||
	Client           Client // MJ Client
 | 
			
		||||
	taskQueue        *store.RedisQueue
 | 
			
		||||
	notifyQueue      *store.RedisQueue
 | 
			
		||||
	db               *gorm.DB
 | 
			
		||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
			
		||||
	handledTaskNum   int32             // already handled task number
 | 
			
		||||
	HandledTaskNum   int32             // already handled task number
 | 
			
		||||
	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
			
		||||
	taskTimeout      int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		name:             name,
 | 
			
		||||
		Name:             name,
 | 
			
		||||
		db:               db,
 | 
			
		||||
		taskQueue:        taskQueue,
 | 
			
		||||
		notifyQueue:      notifyQueue,
 | 
			
		||||
		client:           client,
 | 
			
		||||
		Client:           cli,
 | 
			
		||||
		taskTimeout:      timeout,
 | 
			
		||||
		maxHandleTaskNum: maxTaskNum,
 | 
			
		||||
		taskStartTimes:   make(map[int]time.Time, 0),
 | 
			
		||||
@@ -41,7 +41,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.name)
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.Name)
 | 
			
		||||
	for {
 | 
			
		||||
		s.checkTasks()
 | 
			
		||||
		if !s.canHandleTask() {
 | 
			
		||||
@@ -58,65 +58,72 @@ func (s *Service) Run() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// if it's reference message, check if it's this channel's  message
 | 
			
		||||
		if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
 | 
			
		||||
		//  如果配置了多个中转平台的 API KEY
 | 
			
		||||
		// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
 | 
			
		||||
		if task.ChannelId != "" && task.ChannelId != s.Name {
 | 
			
		||||
			logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
 | 
			
		||||
			s.taskQueue.RPush(task)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 翻译提示词
 | 
			
		||||
		if utils.HasChinese(task.Prompt) {
 | 
			
		||||
		if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Prompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
 | 
			
		||||
		var res ImageRes
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			err = s.client.Imagine(task)
 | 
			
		||||
			res, err = s.Client.Imagine(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			err = s.client.Upscale(task)
 | 
			
		||||
			res, err = s.Client.Upscale(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			err = s.client.Variation(task)
 | 
			
		||||
			res, err = s.Client.Variation(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskBlend:
 | 
			
		||||
			err = s.client.Blend(task)
 | 
			
		||||
			res, err = s.Client.Blend(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskSwapFace:
 | 
			
		||||
			err = s.client.SwapFace(task)
 | 
			
		||||
			res, err = s.Client.SwapFace(task)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err.Error())
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
			
		||||
			errMsg := fmt.Sprintf("%v,%s", err, res.Description)
 | 
			
		||||
			logger.Error("绘画任务执行失败:", errMsg)
 | 
			
		||||
			job.Progress = -1
 | 
			
		||||
			job.ErrMsg = errMsg
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"progress": -1,
 | 
			
		||||
				"err_msg":  err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			s.db.Updates(&job)
 | 
			
		||||
			// 任务失败,通知前端
 | 
			
		||||
			s.notifyQueue.RPush(task.UserId)
 | 
			
		||||
			// restore img_call quota
 | 
			
		||||
			if task.Type.String() != types.TaskUpscale.String() {
 | 
			
		||||
				s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Task Executed: %+v", task)
 | 
			
		||||
		logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
		// lock the task until the execute timeout
 | 
			
		||||
		s.taskStartTimes[int(task.Id)] = time.Now()
 | 
			
		||||
		atomic.AddInt32(&s.handledTaskNum, 1)
 | 
			
		||||
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, 1)
 | 
			
		||||
		// 更新任务 ID/频道
 | 
			
		||||
		job.TaskId = res.Result
 | 
			
		||||
		job.ChannelId = s.Name
 | 
			
		||||
		s.db.Updates(&job)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check if current service instance can handle more task
 | 
			
		||||
func (s *Service) canHandleTask() bool {
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.handledTaskNum)
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.HandledTaskNum)
 | 
			
		||||
	return handledNum < s.maxHandleTaskNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -125,65 +132,75 @@ func (s *Service) checkTasks() {
 | 
			
		||||
	for k, t := range s.taskStartTimes {
 | 
			
		||||
		if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
			
		||||
			delete(s.taskStartTimes, k)
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
			atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
			// delete task from database
 | 
			
		||||
			s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(data CBReq) {
 | 
			
		||||
	// extract the task ID
 | 
			
		||||
	split := strings.Split(data.Prompt, " ")
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := s.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
			
		||||
	if res.Error == nil && data.Status == Finished {
 | 
			
		||||
		logger.Warn("重复消息:", data.MessageId)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
	if data.ReferenceId != "" {
 | 
			
		||||
		tx = tx.Where("reference_id = ?", data.ReferenceId)
 | 
			
		||||
	} else {
 | 
			
		||||
		tx = tx.Where("task_id = ?", split[0])
 | 
			
		||||
	}
 | 
			
		||||
	// fixed: 修复 U/V 操作任务混淆覆盖的 Bug
 | 
			
		||||
	if strings.Contains(data.Prompt, "** - Image #") { // for upscale
 | 
			
		||||
		tx = tx.Where("type = ?", types.TaskUpscale.String())
 | 
			
		||||
	} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations
 | 
			
		||||
		tx = tx.Where("type = ?", types.TaskVariation.String())
 | 
			
		||||
	}
 | 
			
		||||
	res = tx.First(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Warn("非法任务:", res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	job.ChannelId = data.ChannelId
 | 
			
		||||
	job.MessageId = data.MessageId
 | 
			
		||||
	job.ReferenceId = data.ReferenceId
 | 
			
		||||
	job.Progress = data.Progress
 | 
			
		||||
	job.Prompt = data.Prompt
 | 
			
		||||
	job.Hash = data.Image.Hash
 | 
			
		||||
	if s.client.Config.UseCDN {
 | 
			
		||||
		job.UseProxy = true
 | 
			
		||||
		job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
 | 
			
		||||
	} else {
 | 
			
		||||
		job.OrgURL = data.Image.URL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = s.db.Updates(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update job: ", res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.Status == Finished {
 | 
			
		||||
		// release lock task
 | 
			
		||||
		atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.notifyQueue.RPush(job.UserId)
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	Id          string      `json:"id"`
 | 
			
		||||
	Action      string      `json:"action"`
 | 
			
		||||
	Status      string      `json:"status"`
 | 
			
		||||
	Prompt      string      `json:"prompt"`
 | 
			
		||||
	PromptEn    string      `json:"promptEn"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	SubmitTime  int64       `json:"submitTime"`
 | 
			
		||||
	StartTime   int64       `json:"startTime"`
 | 
			
		||||
	FinishTime  int64       `json:"finishTime"`
 | 
			
		||||
	Progress    string      `json:"progress"`
 | 
			
		||||
	ImageUrl    string      `json:"imageUrl"`
 | 
			
		||||
	FailReason  interface{} `json:"failReason"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
		FinalPrompt string `json:"finalPrompt"`
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	task, err := s.Client.QueryTask(job.TaskId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 任务执行失败了
 | 
			
		||||
	if task.FailReason != "" {
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  task.FailReason,
 | 
			
		||||
		})
 | 
			
		||||
		return fmt.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(task.Buttons) > 0 {
 | 
			
		||||
		job.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
	}
 | 
			
		||||
	oldProgress := job.Progress
 | 
			
		||||
	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
	job.Prompt = task.PromptEn
 | 
			
		||||
	if task.ImageUrl != "" {
 | 
			
		||||
		job.OrgURL = task.ImageUrl
 | 
			
		||||
	}
 | 
			
		||||
	job.MessageId = task.Id
 | 
			
		||||
	tx := s.db.Updates(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with update database: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
	if task.Status == "SUCCESS" {
 | 
			
		||||
		// release lock task
 | 
			
		||||
		atomic.AddInt32(&s.HandledTaskNum, -1)
 | 
			
		||||
	}
 | 
			
		||||
	// 通知前端更新任务进度
 | 
			
		||||
	if oldProgress != job.Progress {
 | 
			
		||||
		s.notifyQueue.RPush(job.UserId)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetImageHash(action string) string {
 | 
			
		||||
	split := strings.Split(action, "::")
 | 
			
		||||
	if len(split) > 5 {
 | 
			
		||||
		return split[4]
 | 
			
		||||
	}
 | 
			
		||||
	return split[len(split)-1]
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,35 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ApplicationID string = "936929561302675456"
 | 
			
		||||
	SessionID     string = "ea8816d857ba9ae2f74c59ae1a953afe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type InteractionsRequest struct {
 | 
			
		||||
	Type          int            `json:"type"`
 | 
			
		||||
	ApplicationID string         `json:"application_id"`
 | 
			
		||||
	MessageFlags  int            `json:"message_flags,omitempty"`
 | 
			
		||||
	MessageID     string         `json:"message_id,omitempty"`
 | 
			
		||||
	GuildID       string         `json:"guild_id"`
 | 
			
		||||
	ChannelID     string         `json:"channel_id"`
 | 
			
		||||
	SessionID     string         `json:"session_id"`
 | 
			
		||||
	Data          map[string]any `json:"data"`
 | 
			
		||||
	Nonce         string         `json:"nonce,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InteractionsResult struct {
 | 
			
		||||
	Code    int `json:"code"`
 | 
			
		||||
	Message string
 | 
			
		||||
	Error   map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	ChannelId   string     `json:"channel_id"`
 | 
			
		||||
	MessageId   string     `json:"message_id"`
 | 
			
		||||
	ReferenceId string     `json:"reference_id"`
 | 
			
		||||
	Image       Image      `json:"image"`
 | 
			
		||||
	Content     string     `json:"content"`
 | 
			
		||||
	Prompt      string     `json:"prompt"`
 | 
			
		||||
	Status      TaskStatus `json:"status"`
 | 
			
		||||
	Progress    int        `json:"progress"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,5 +1,11 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	text := "一只 蜗牛在树干上爬,阳光透过树叶照在蜗牛的背上 --ar 1:1 --iw 0.250000 --v 6"
 | 
			
		||||
	fmt.Println(utils.HasChinese(text))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user