mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
feat: change midjourney origin implements, replace midjourney bot with midjourney-proxy
This commit is contained in:
parent
b5947545cb
commit
4fb2c5803c
@ -1,4 +1,13 @@
|
|||||||
# 更新日志
|
# 更新日志
|
||||||
|
## v4.0.1
|
||||||
|
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些
|
||||||
|
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转
|
||||||
|
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
|
||||||
|
* 功能新增:支持前端菜单可以配置
|
||||||
|
* 功能优化:手机端支持免登录预览功能
|
||||||
|
* 功能新增:手机端支持 Stable-Diffusion 绘画
|
||||||
|
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
|
||||||
|
|
||||||
## v4.0.0
|
## v4.0.0
|
||||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
|
||||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
|
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力...
|
||||||
|
@ -5,22 +5,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Path string `toml:"-"`
|
Path string `toml:"-"`
|
||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
AdminSession Session
|
AdminSession Session
|
||||||
ProxyURL string
|
ProxyURL string
|
||||||
MysqlDns string // mysql 连接地址
|
MysqlDns string // mysql 连接地址
|
||||||
StaticDir string // 静态资源目录
|
StaticDir string // 静态资源目录
|
||||||
StaticUrl string // 静态资源 URL
|
StaticUrl string // 静态资源 URL
|
||||||
Redis RedisConfig // redis 连接信息
|
Redis RedisConfig // redis 连接信息
|
||||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||||
SMS SMSConfig // send mobile message config
|
SMS SMSConfig // send mobile message config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
MjConfigs []MidJourneyConfig // mj AI draw service pool
|
MjProxyConfigs []MjProxyConfig // MJ proxy config
|
||||||
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config
|
MjPlusConfigs []MjPlusConfig // MJ plus config
|
||||||
WeChatBot bool // 是否启用微信机器人
|
WeChatBot bool // 是否启用微信机器人
|
||||||
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
SdConfigs []StableDiffusionConfig // sd AI draw service pool
|
||||||
|
|
||||||
XXLConfig XXLConfig
|
XXLConfig XXLConfig
|
||||||
AlipayConfig AlipayConfig
|
AlipayConfig AlipayConfig
|
||||||
@ -43,16 +43,11 @@ type ChatPlusApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyConfig struct {
|
type MjProxyConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
UserToken string
|
ApiURL string // api 地址
|
||||||
BotToken string
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
GuildId string // Server ID
|
ApiKey string
|
||||||
ChanelId string // Chanel ID
|
|
||||||
UseCDN bool
|
|
||||||
ImgCdnURL string // 图片反代加速地址
|
|
||||||
DiscordAPI string
|
|
||||||
DiscordGateway string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StableDiffusionConfig struct {
|
type StableDiffusionConfig struct {
|
||||||
@ -62,12 +57,11 @@ type StableDiffusionConfig struct {
|
|||||||
ApiKey string
|
ApiKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyPlusConfig struct {
|
type MjPlusConfig struct {
|
||||||
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
|
||||||
ApiURL string // api 地址
|
ApiURL string // api 地址
|
||||||
Mode string // 绘画模式,可选值:fast/turbo/relax
|
Mode string // 绘画模式,可选值:fast/turbo/relax
|
||||||
ApiKey string
|
ApiKey string
|
||||||
NotifyURL string // 任务进度更新回调地址
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AlipayConfig struct {
|
type AlipayConfig struct {
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service"
|
"chatplus/service"
|
||||||
"chatplus/service/mj"
|
"chatplus/service/mj"
|
||||||
"chatplus/service/mj/plus"
|
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@ -454,27 +453,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify MidJourney Plus 服务任务回调处理
|
|
||||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|
||||||
var data plus.CBReq
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
logger.Error("非法任务回调:%+v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := h.pool.Notify(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
} else {
|
|
||||||
userId := h.GetLoginUserId(c)
|
|
||||||
client := h.pool.Clients.Get(userId)
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Send([]byte("Task Updated"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish 发布图片到画廊显示
|
// Publish 发布图片到画廊显示
|
||||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
func (h *MidJourneyHandler) Publish(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
|
@ -252,7 +252,6 @@ func main() {
|
|||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
group.GET("imgWall", h.ImgWall)
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
group.POST("notify", h.Notify)
|
|
||||||
group.POST("publish", h.Publish)
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
|
@ -1,233 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/utils"
|
|
||||||
discordgo "github.com/bg5t/mydiscordgo"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MidJourney 机器人
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
type Bot struct {
|
|
||||||
config types.MidJourneyConfig
|
|
||||||
bot *discordgo.Session
|
|
||||||
name string
|
|
||||||
service *Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) {
|
|
||||||
bot, err := discordgo.New("Bot " + config.BotToken)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// use CDN reverse proxy
|
|
||||||
if config.UseCDN {
|
|
||||||
discordgo.SetEndpointDiscord(config.DiscordAPI)
|
|
||||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
|
||||||
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
|
|
||||||
bot.MjGateway = config.DiscordGateway + "/"
|
|
||||||
} else { // use proxy
|
|
||||||
discordgo.SetEndpointDiscord("https://discord.com")
|
|
||||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
|
||||||
discordgo.SetEndpointStatus("https://discord.com/api/v2/")
|
|
||||||
bot.MjGateway = "wss://gateway.discord.gg"
|
|
||||||
|
|
||||||
if proxy != "" {
|
|
||||||
proxy, _ := url.Parse(proxy)
|
|
||||||
bot.Client = &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
bot.Dialer = &websocket.Dialer{
|
|
||||||
Proxy: http.ProxyURL(proxy),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Bot{
|
|
||||||
config: config,
|
|
||||||
bot: bot,
|
|
||||||
name: name,
|
|
||||||
service: service,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) Run() error {
|
|
||||||
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
|
|
||||||
b.bot.AddHandler(b.messageCreate)
|
|
||||||
b.bot.AddHandler(b.messageUpdate)
|
|
||||||
|
|
||||||
logger.Infof("Starting MidJourney %s", b.name)
|
|
||||||
err := b.bot.Open()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Infof("Starting MidJourney %s successfully!", b.name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Start = TaskStatus("Started")
|
|
||||||
Running = TaskStatus("Running")
|
|
||||||
Stopped = TaskStatus("Stopped")
|
|
||||||
Finished = TaskStatus("Finished")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Image struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
ProxyURL string `json:"proxy_url"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Width int `json:"width"`
|
|
||||||
Height int `json:"height"`
|
|
||||||
Size int `json:"size"`
|
|
||||||
Hash string `json:"hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
|
||||||
// parse content
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: 0,
|
|
||||||
Status: Start}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
|
||||||
// ignore messages for other channels
|
|
||||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ignore messages for self
|
|
||||||
if m.Author == nil || m.Author.ID == s.State.User.ID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
|
|
||||||
|
|
||||||
var referenceId = ""
|
|
||||||
if m.ReferencedMessage != nil {
|
|
||||||
referenceId = m.ReferencedMessage.ID
|
|
||||||
}
|
|
||||||
if strings.Contains(m.Content, "(Stopped)") {
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: m.ChannelID,
|
|
||||||
MessageId: m.ID,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Prompt: extractPrompt(m.Content),
|
|
||||||
Content: m.Content,
|
|
||||||
Progress: extractProgress(m.Content),
|
|
||||||
Status: Stopped}
|
|
||||||
b.service.Notify(req)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
|
||||||
progress := extractProgress(content)
|
|
||||||
var status TaskStatus
|
|
||||||
if progress == 100 {
|
|
||||||
status = Finished
|
|
||||||
} else {
|
|
||||||
status = Running
|
|
||||||
}
|
|
||||||
for _, attachment := range attachments {
|
|
||||||
if attachment.Width == 0 || attachment.Height == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
image := Image{
|
|
||||||
URL: attachment.URL,
|
|
||||||
Height: attachment.Height,
|
|
||||||
ProxyURL: attachment.ProxyURL,
|
|
||||||
Width: attachment.Width,
|
|
||||||
Size: attachment.Size,
|
|
||||||
Filename: attachment.Filename,
|
|
||||||
Hash: extractHashFromFilename(attachment.Filename),
|
|
||||||
}
|
|
||||||
req := CBReq{
|
|
||||||
ChannelId: channelId,
|
|
||||||
MessageId: messageId,
|
|
||||||
ReferenceId: referenceId,
|
|
||||||
Image: image,
|
|
||||||
Prompt: extractPrompt(content),
|
|
||||||
Content: content,
|
|
||||||
Progress: progress,
|
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
b.service.Notify(req)
|
|
||||||
break // only get one image
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract prompt from string
|
|
||||||
func extractPrompt(input string) string {
|
|
||||||
pattern := `\*\*(.*?)\*\*`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return strings.TrimSpace(matches[1])
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractProgress(input string) int {
|
|
||||||
pattern := `\((\d+)\%\)`
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(input)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return utils.IntValue(matches[1], 0)
|
|
||||||
}
|
|
||||||
return 100
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractHashFromFilename(filename string) string {
|
|
||||||
if !strings.HasSuffix(filename, ".png") {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
index := strings.LastIndex(filename, "_")
|
|
||||||
if index != -1 {
|
|
||||||
return filename[index+1 : len(filename)-4]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
@ -1,159 +1,61 @@
|
|||||||
package mj
|
package mj
|
||||||
|
|
||||||
import (
|
import "chatplus/core/types"
|
||||||
"chatplus/core/types"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
type Client interface {
|
||||||
)
|
Imagine(task types.MjTask) (ImageRes, error)
|
||||||
|
Blend(task types.MjTask) (ImageRes, error)
|
||||||
// MidJourney client
|
SwapFace(task types.MjTask) (ImageRes, error)
|
||||||
|
Upscale(task types.MjTask) (ImageRes, error)
|
||||||
type Client struct {
|
Variation(task types.MjTask) (ImageRes, error)
|
||||||
client *req.Client
|
QueryTask(taskId string) (QueryRes, error)
|
||||||
Config types.MidJourneyConfig
|
|
||||||
apiURL string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
type ImageReq struct {
|
||||||
client := req.C().SetTimeout(10 * time.Second)
|
BotType string `json:"botType,omitempty"`
|
||||||
var apiURL string
|
Prompt string `json:"prompt,omitempty"`
|
||||||
// set proxy URL
|
Dimensions string `json:"dimensions,omitempty"`
|
||||||
if config.UseCDN {
|
Base64Array []string `json:"base64Array,omitempty"`
|
||||||
apiURL = config.DiscordAPI + "/api/v9/interactions"
|
AccountFilter interface{} `json:"accountFilter,omitempty"`
|
||||||
} else {
|
NotifyHook string `json:"notifyHook,omitempty"`
|
||||||
apiURL = "https://discord.com/api/v9/interactions"
|
State string `json:"state,omitempty"`
|
||||||
if proxy != "" {
|
|
||||||
client.SetProxyURL(proxy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{client: client, Config: config, apiURL: apiURL}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Imagine(task types.MjTask) error {
|
type ImageRes struct {
|
||||||
interactionsReq := &InteractionsRequest{
|
Code int `json:"code"`
|
||||||
Type: 2,
|
Description string `json:"description"`
|
||||||
ApplicationID: ApplicationID,
|
Properties struct {
|
||||||
GuildID: c.Config.GuildId,
|
} `json:"properties"`
|
||||||
ChannelID: c.Config.ChanelId,
|
Result string `json:"result"`
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"version": "1166847114203123795",
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"name": "imagine",
|
|
||||||
"type": "1",
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"application_command": map[string]any{
|
|
||||||
"id": "938956540159881230",
|
|
||||||
"application_id": ApplicationID,
|
|
||||||
"version": "1118961510123847772",
|
|
||||||
"default_permission": true,
|
|
||||||
"default_member_permissions": nil,
|
|
||||||
"type": 1,
|
|
||||||
"nsfw": false,
|
|
||||||
"name": "imagine",
|
|
||||||
"description": "Create images with Midjourney",
|
|
||||||
"dm_permission": true,
|
|
||||||
"options": []map[string]any{
|
|
||||||
{
|
|
||||||
"type": 3,
|
|
||||||
"name": "prompt",
|
|
||||||
"description": "The prompt to imagine",
|
|
||||||
"required": true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"attachments": []any{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
Post(c.apiURL)
|
|
||||||
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %w%v", err, r.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Blend(task types.MjTask) error {
|
type ErrRes struct {
|
||||||
return errors.New("function not implemented")
|
Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SwapFace(task types.MjTask) error {
|
type QueryRes struct {
|
||||||
return errors.New("function not implemented")
|
Action string `json:"action"`
|
||||||
}
|
Buttons []struct {
|
||||||
|
CustomId string `json:"customId"`
|
||||||
// Upscale 放大指定的图片
|
Emoji string `json:"emoji"`
|
||||||
func (c *Client) Upscale(task types.MjTask) error {
|
Label string `json:"label"`
|
||||||
flags := 0
|
Style int `json:"style"`
|
||||||
interactionsReq := &InteractionsRequest{
|
Type int `json:"type"`
|
||||||
Type: 3,
|
} `json:"buttons"`
|
||||||
ApplicationID: ApplicationID,
|
Description string `json:"description"`
|
||||||
GuildID: c.Config.GuildId,
|
FailReason string `json:"failReason"`
|
||||||
ChannelID: c.Config.ChanelId,
|
FinishTime int `json:"finishTime"`
|
||||||
MessageFlags: flags,
|
Id string `json:"id"`
|
||||||
MessageID: task.MessageId,
|
ImageUrl string `json:"imageUrl"`
|
||||||
SessionID: SessionID,
|
Progress string `json:"progress"`
|
||||||
Data: map[string]any{
|
Prompt string `json:"prompt"`
|
||||||
"component_type": 2,
|
PromptEn string `json:"promptEn"`
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
Properties struct {
|
||||||
},
|
} `json:"properties"`
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
StartTime int `json:"startTime"`
|
||||||
}
|
State string `json:"state"`
|
||||||
|
Status string `json:"status"`
|
||||||
var res InteractionsResult
|
SubmitTime int `json:"submitTime"`
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
SetErrorResult(&res).
|
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
|
||||||
func (c *Client) Variation(task types.MjTask) error {
|
|
||||||
flags := 0
|
|
||||||
interactionsReq := &InteractionsRequest{
|
|
||||||
Type: 3,
|
|
||||||
ApplicationID: ApplicationID,
|
|
||||||
GuildID: c.Config.GuildId,
|
|
||||||
ChannelID: c.Config.ChanelId,
|
|
||||||
MessageFlags: flags,
|
|
||||||
MessageID: task.MessageId,
|
|
||||||
SessionID: SessionID,
|
|
||||||
Data: map[string]any{
|
|
||||||
"component_type": 2,
|
|
||||||
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
|
||||||
},
|
|
||||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
||||||
}
|
|
||||||
|
|
||||||
var res InteractionsResult
|
|
||||||
r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(interactionsReq).
|
|
||||||
SetErrorResult(&res).
|
|
||||||
Post(c.apiURL)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,194 +0,0 @@
|
|||||||
package plus
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
"chatplus/store"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Service MJ 绘画服务
|
|
||||||
type Service struct {
|
|
||||||
Name string // service Name
|
|
||||||
Client *Client // MJ Client
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
notifyQueue *store.RedisQueue
|
|
||||||
db *gorm.DB
|
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
|
||||||
HandledTaskNum int32 // already handled task number
|
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
||||||
taskTimeout int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
|
||||||
return &Service{
|
|
||||||
Name: name,
|
|
||||||
db: db,
|
|
||||||
taskQueue: taskQueue,
|
|
||||||
notifyQueue: notifyQueue,
|
|
||||||
Client: client,
|
|
||||||
taskTimeout: timeout,
|
|
||||||
maxHandleTaskNum: maxTaskNum,
|
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Run() {
|
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
|
||||||
for {
|
|
||||||
s.checkTasks()
|
|
||||||
if !s.canHandleTask() {
|
|
||||||
// current service is full, can not handle more task
|
|
||||||
// waiting for running task finish
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
|
||||||
err := s.taskQueue.LPop(&task)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("taking task with error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// if it's reference message, check if it's this channel's message
|
|
||||||
//if task.ChannelId != "" && task.ChannelId != s.Name {
|
|
||||||
// logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
|
||||||
// s.taskQueue.RPush(task)
|
|
||||||
// time.Sleep(time.Second)
|
|
||||||
// continue
|
|
||||||
//}
|
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
|
||||||
var res ImageRes
|
|
||||||
switch task.Type {
|
|
||||||
case types.TaskImage:
|
|
||||||
res, err = s.Client.Imagine(task)
|
|
||||||
break
|
|
||||||
case types.TaskUpscale:
|
|
||||||
res, err = s.Client.Upscale(task)
|
|
||||||
break
|
|
||||||
case types.TaskVariation:
|
|
||||||
res, err = s.Client.Variation(task)
|
|
||||||
break
|
|
||||||
case types.TaskBlend:
|
|
||||||
res, err = s.Client.Blend(task)
|
|
||||||
break
|
|
||||||
case types.TaskSwapFace:
|
|
||||||
res, err = s.Client.SwapFace(task)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
s.db.Where("id = ?", task.Id).First(&job)
|
|
||||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
|
||||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
|
||||||
logger.Error("绘画任务执行失败:", errMsg)
|
|
||||||
job.Progress = -1
|
|
||||||
job.ErrMsg = errMsg
|
|
||||||
// update the task progress
|
|
||||||
s.db.Updates(&job)
|
|
||||||
// 任务失败,通知前端
|
|
||||||
s.notifyQueue.RPush(task.UserId)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Infof("任务提交成功:%+v", res)
|
|
||||||
// lock the task until the execute timeout
|
|
||||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
|
||||||
// 更新任务 ID/频道
|
|
||||||
job.TaskId = res.Result
|
|
||||||
job.ChannelId = s.Name
|
|
||||||
s.db.Updates(&job)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
|
||||||
func (s *Service) canHandleTask() bool {
|
|
||||||
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
|
||||||
return handledNum < s.maxHandleTaskNum
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the expired tasks
|
|
||||||
func (s *Service) checkTasks() {
|
|
||||||
for k, t := range s.taskStartTimes {
|
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
|
||||||
delete(s.taskStartTimes, k)
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
// delete task from database
|
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
SubmitTime int64 `json:"submitTime"`
|
|
||||||
StartTime int64 `json:"startTime"`
|
|
||||||
FinishTime int64 `json:"finishTime"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
FailReason interface{} `json:"failReason"`
|
|
||||||
Properties struct {
|
|
||||||
FinalPrompt string `json:"finalPrompt"`
|
|
||||||
} `json:"properties"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
|
||||||
task, err := s.Client.QueryTask(job.TaskId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 任务执行失败了
|
|
||||||
if task.FailReason != "" {
|
|
||||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
|
||||||
"progress": -1,
|
|
||||||
"err_msg": task.FailReason,
|
|
||||||
})
|
|
||||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(task.Buttons) > 0 {
|
|
||||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
|
||||||
}
|
|
||||||
oldProgress := job.Progress
|
|
||||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
|
||||||
job.Prompt = task.PromptEn
|
|
||||||
if task.ImageUrl != "" {
|
|
||||||
job.OrgURL = task.ImageUrl
|
|
||||||
}
|
|
||||||
job.MessageId = task.Id
|
|
||||||
tx := s.db.Updates(&job)
|
|
||||||
if tx.Error != nil {
|
|
||||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
|
||||||
}
|
|
||||||
if task.Status == "SUCCESS" {
|
|
||||||
// release lock task
|
|
||||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
|
||||||
}
|
|
||||||
// 通知前端更新任务进度
|
|
||||||
if oldProgress != job.Progress {
|
|
||||||
s.notifyQueue.RPush(job.UserId)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetImageHash(action string) string {
|
|
||||||
split := strings.Split(action, "::")
|
|
||||||
if len(split) > 5 {
|
|
||||||
return split[4]
|
|
||||||
}
|
|
||||||
return split[len(split)-1]
|
|
||||||
}
|
|
@ -1,8 +1,7 @@
|
|||||||
package plus
|
package mj
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
@ -13,53 +12,21 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
// PlusClient MidJourney Plus ProxyClient
|
||||||
|
type PlusClient struct {
|
||||||
// Client MidJourney Plus Client
|
Config types.MjPlusConfig
|
||||||
type Client struct {
|
|
||||||
Config types.MidJourneyPlusConfig
|
|
||||||
apiURL string
|
apiURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
|
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
|
||||||
return &Client{Config: config, apiURL: config.ApiURL}
|
return &PlusClient{Config: config, apiURL: config.ApiURL}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageReq struct {
|
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
BotType string `json:"botType"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Dimensions string `json:"dimensions,omitempty"`
|
|
||||||
Base64Array []string `json:"base64Array,omitempty"`
|
|
||||||
AccountFilter struct {
|
|
||||||
InstanceId string `json:"instanceId"`
|
|
||||||
Modes []interface{} `json:"modes"`
|
|
||||||
Remix bool `json:"remix"`
|
|
||||||
RemixAutoConsidered bool `json:"remixAutoConsidered"`
|
|
||||||
} `json:"accountFilter,omitempty"`
|
|
||||||
NotifyHook string `json:"notifyHook"`
|
|
||||||
State string `json:"state,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageRes struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrRes struct {
|
|
||||||
Error struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
} `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
|
||||||
body := ImageReq{
|
body := ImageReq{
|
||||||
BotType: "MID_JOURNEY",
|
BotType: "MID_JOURNEY",
|
||||||
Prompt: task.Prompt,
|
Prompt: task.Prompt,
|
||||||
NotifyHook: c.Config.NotifyURL,
|
|
||||||
Base64Array: make([]string, 0),
|
Base64Array: make([]string, 0),
|
||||||
}
|
}
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
@ -94,12 +61,11 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Blend 融图
|
// Blend 融图
|
||||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
|
||||||
body := ImageReq{
|
body := ImageReq{
|
||||||
BotType: "MID_JOURNEY",
|
BotType: "MID_JOURNEY",
|
||||||
Dimensions: "SQUARE",
|
Dimensions: "SQUARE",
|
||||||
NotifyHook: c.Config.NotifyURL,
|
|
||||||
Base64Array: make([]string, 0),
|
Base64Array: make([]string, 0),
|
||||||
}
|
}
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
@ -133,7 +99,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SwapFace 换脸
|
// SwapFace 换脸
|
||||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
if len(task.ImgArr) != 2 {
|
if len(task.ImgArr) != 2 {
|
||||||
@ -160,8 +126,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
"accountFilter": gin.H{
|
"accountFilter": gin.H{
|
||||||
"instanceId": "",
|
"instanceId": "",
|
||||||
},
|
},
|
||||||
"notifyHook": c.Config.NotifyURL,
|
"state": "",
|
||||||
"state": "",
|
|
||||||
}
|
}
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
var errRes ErrRes
|
var errRes ErrRes
|
||||||
@ -183,11 +148,10 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upscale 放大指定的图片
|
// Upscale 放大指定的图片
|
||||||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
"notifyHook": c.Config.NotifyURL,
|
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
@ -210,11 +174,10 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
body := map[string]string{
|
body := map[string]string{
|
||||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
|
||||||
"taskId": task.MessageId,
|
"taskId": task.MessageId,
|
||||||
"notifyHook": c.Config.NotifyURL,
|
|
||||||
}
|
}
|
||||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
@ -236,32 +199,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryRes struct {
|
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
Action string `json:"action"`
|
|
||||||
Buttons []struct {
|
|
||||||
CustomId string `json:"customId"`
|
|
||||||
Emoji string `json:"emoji"`
|
|
||||||
Label string `json:"label"`
|
|
||||||
Style int `json:"style"`
|
|
||||||
Type int `json:"type"`
|
|
||||||
} `json:"buttons"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
FailReason string `json:"failReason"`
|
|
||||||
FinishTime int `json:"finishTime"`
|
|
||||||
Id string `json:"id"`
|
|
||||||
ImageUrl string `json:"imageUrl"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
PromptEn string `json:"promptEn"`
|
|
||||||
Properties struct {
|
|
||||||
} `json:"properties"`
|
|
||||||
StartTime int `json:"startTime"`
|
|
||||||
State string `json:"state"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
SubmitTime int `json:"submitTime"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
var res QueryRes
|
var res QueryRes
|
||||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||||
@ -278,3 +216,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
|||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Client = &PlusClient{}
|
@ -2,13 +2,12 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/mj/plus"
|
logger2 "chatplus/logger"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -16,7 +15,7 @@ import (
|
|||||||
|
|
||||||
// ServicePool Mj service pool
|
// ServicePool Mj service pool
|
||||||
type ServicePool struct {
|
type ServicePool struct {
|
||||||
services []interface{}
|
services []*Service
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@ -24,8 +23,10 @@ type ServicePool struct {
|
|||||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||||
services := make([]interface{}, 0)
|
services := make([]*Service, 0)
|
||||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||||
|
|
||||||
@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
|||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := plus.NewClient(config)
|
cli := NewPlusClient(config)
|
||||||
name := fmt.Sprintf("mj-service-plus-%d", k)
|
name := fmt.Sprintf("mj-plus-service-%d", k)
|
||||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||||
go func() {
|
go func() {
|
||||||
servicePlus.Run()
|
service.Run()
|
||||||
}()
|
}()
|
||||||
services = append(services, servicePlus)
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(services) == 0 {
|
for k, config := range appConfig.MjProxyConfigs {
|
||||||
// create mj client and service
|
if config.Enabled == false {
|
||||||
for k, config := range appConfig.MjConfigs {
|
continue
|
||||||
if config.Enabled == false {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// create mj client
|
|
||||||
client := NewClient(config, appConfig.ProxyURL)
|
|
||||||
|
|
||||||
name := fmt.Sprintf("MjService-%d", k)
|
|
||||||
// create mj service
|
|
||||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
|
||||||
botName := fmt.Sprintf("MjBot-%d", k)
|
|
||||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bot.Run()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// run mj service
|
|
||||||
go func() {
|
|
||||||
service.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
services = append(services, service)
|
|
||||||
}
|
}
|
||||||
|
cli := NewProxyClient(config)
|
||||||
|
name := fmt.Sprintf("mj-proxy-service-%d", k)
|
||||||
|
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli)
|
||||||
|
go func() {
|
||||||
|
service.Run()
|
||||||
|
}()
|
||||||
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ServicePool{
|
return &ServicePool{
|
||||||
@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client := p.Clients.Get(userId)
|
cli := p.Clients.Get(userId)
|
||||||
if client == nil {
|
if cli == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte("Task Updated"))
|
err = cli.Send([]byte("Task Updated"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() {
|
|||||||
logger.Infof("try to download image: %s", v.OrgURL)
|
logger.Infof("try to download image: %s", v.OrgURL)
|
||||||
var imgURL string
|
var imgURL string
|
||||||
var err error
|
var err error
|
||||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
|
||||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||||
if len(task.Buttons) > 0 {
|
if len(task.Buttons) > 0 {
|
||||||
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
|
v.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
}
|
}
|
||||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
||||||
} else {
|
} else {
|
||||||
@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() {
|
|||||||
v.ImgURL = imgURL
|
v.ImgURL = imgURL
|
||||||
p.db.Updates(&v)
|
p.db.Updates(&v)
|
||||||
|
|
||||||
client := p.Clients.Get(uint(v.UserId))
|
cli := p.Clients.Get(uint(v.UserId))
|
||||||
if client == nil {
|
if cli == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = client.Send([]byte("Task Updated"))
|
err = cli.Send([]byte("Task Updated"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -167,25 +149,6 @@ func (p *ServicePool) HasAvailableService() bool {
|
|||||||
return len(p.services) > 0
|
return len(p.services) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ServicePool) Notify(data plus.CBReq) error {
|
|
||||||
logger.Debugf("收到任务回调:%+v", data)
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := p.db.Where("task_id = ?", data.Id).First(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
return fmt.Errorf("非法任务:%s", data.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 任务已经拉取完成
|
|
||||||
if job.Progress == 100 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
|
||||||
return servicePlus.Notify(job)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncTaskProgress 异步拉取任务
|
// SyncTaskProgress 异步拉取任务
|
||||||
func (p *ServicePool) SyncTaskProgress() {
|
func (p *ServicePool) SyncTaskProgress() {
|
||||||
go func() {
|
go func() {
|
||||||
@ -222,11 +185,7 @@ func (p *ServicePool) SyncTaskProgress() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(job.ChannelId, "mj-service-plus") {
|
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
|
||||||
_ = servicePlus.Notify(job)
|
_ = servicePlus.Notify(job)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,12 +195,10 @@ func (p *ServicePool) SyncTaskProgress() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
func (p *ServicePool) getService(name string) *Service {
|
||||||
for _, s := range p.services {
|
for _, s := range p.services {
|
||||||
if servicePlus, ok := s.(*plus.Service); ok {
|
if s.Name == name {
|
||||||
if servicePlus.Name == name {
|
return s
|
||||||
return servicePlus
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
176
api/service/mj/proxy_client.go
Normal file
176
api/service/mj/proxy_client.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyClient MidJourney Proxy Client
|
||||||
|
type ProxyClient struct {
|
||||||
|
Config types.MjProxyConfig
|
||||||
|
apiURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
|
||||||
|
return &ProxyClient{Config: config, apiURL: config.ApiURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
Prompt: task.Prompt,
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
logger.Info("API URL: ", apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
all, err := io.ReadAll(r.Body)
|
||||||
|
logger.Info(string(all))
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend 融图
|
||||||
|
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
|
||||||
|
body := ImageReq{
|
||||||
|
Dimensions: "SQUARE",
|
||||||
|
Base64Array: make([]string, 0),
|
||||||
|
}
|
||||||
|
// 生成图片 Base64 编码
|
||||||
|
if len(task.ImgArr) > 0 {
|
||||||
|
for _, imgURL := range task.ImgArr {
|
||||||
|
imageData, err := utils.DownloadImage(imgURL, "")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
} else {
|
||||||
|
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwapFace 换脸
|
||||||
|
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
|
||||||
|
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "UPSCALE",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"action": "VARIATION",
|
||||||
|
"index": task.Index,
|
||||||
|
"taskId": task.MessageId,
|
||||||
|
}
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
|
||||||
|
var res ImageRes
|
||||||
|
var errRes ErrRes
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
Post(apiURL)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
|
||||||
|
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||||
|
var res QueryRes
|
||||||
|
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return QueryRes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.IsErrorState() {
|
||||||
|
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Client = &ProxyClient{}
|
@ -16,24 +16,24 @@ import (
|
|||||||
|
|
||||||
// Service MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
type Service struct {
|
type Service struct {
|
||||||
name string // service name
|
Name string // service Name
|
||||||
client *Client // MJ client
|
Client Client // MJ Client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
notifyQueue *store.RedisQueue
|
notifyQueue *store.RedisQueue
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
maxHandleTaskNum int32 // max task number current service can handle
|
maxHandleTaskNum int32 // max task number current service can handle
|
||||||
handledTaskNum int32 // already handled task number
|
HandledTaskNum int32 // already handled task number
|
||||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||||
taskTimeout int64
|
taskTimeout int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
Name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: taskQueue,
|
taskQueue: taskQueue,
|
||||||
notifyQueue: notifyQueue,
|
notifyQueue: notifyQueue,
|
||||||
client: client,
|
Client: cli,
|
||||||
taskTimeout: timeout,
|
taskTimeout: timeout,
|
||||||
maxHandleTaskNum: maxTaskNum,
|
maxHandleTaskNum: maxTaskNum,
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
taskStartTimes: make(map[int]time.Time, 0),
|
||||||
@ -41,7 +41,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||||||
for {
|
for {
|
||||||
s.checkTasks()
|
s.checkTasks()
|
||||||
if !s.canHandleTask() {
|
if !s.canHandleTask() {
|
||||||
@ -58,65 +58,72 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's reference message, check if it's this channel's message
|
// 如果配置了多个中转平台的 API KEY
|
||||||
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
||||||
|
if task.ChannelId != "" && task.ChannelId != s.Name {
|
||||||
|
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
||||||
s.taskQueue.RPush(task)
|
s.taskQueue.RPush(task)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 翻译提示词
|
// 翻译提示词
|
||||||
if utils.HasChinese(task.Prompt) {
|
if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
|
||||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
|
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
task.Prompt = content
|
task.Prompt = content
|
||||||
|
} else {
|
||||||
|
logger.Warnf("error with translate prompt: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
||||||
|
var res ImageRes
|
||||||
switch task.Type {
|
switch task.Type {
|
||||||
case types.TaskImage:
|
case types.TaskImage:
|
||||||
err = s.client.Imagine(task)
|
res, err = s.Client.Imagine(task)
|
||||||
break
|
break
|
||||||
case types.TaskUpscale:
|
case types.TaskUpscale:
|
||||||
err = s.client.Upscale(task)
|
res, err = s.Client.Upscale(task)
|
||||||
break
|
break
|
||||||
case types.TaskVariation:
|
case types.TaskVariation:
|
||||||
err = s.client.Variation(task)
|
res, err = s.Client.Variation(task)
|
||||||
break
|
break
|
||||||
case types.TaskBlend:
|
case types.TaskBlend:
|
||||||
err = s.client.Blend(task)
|
res, err = s.Client.Blend(task)
|
||||||
break
|
break
|
||||||
case types.TaskSwapFace:
|
case types.TaskSwapFace:
|
||||||
err = s.client.SwapFace(task)
|
res, err = s.Client.SwapFace(task)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
var job model.MidJourneyJob
|
||||||
logger.Error("绘画任务执行失败:", err.Error())
|
s.db.Where("id = ?", task.Id).First(&job)
|
||||||
|
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||||
|
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
||||||
|
logger.Error("绘画任务执行失败:", errMsg)
|
||||||
|
job.Progress = -1
|
||||||
|
job.ErrMsg = errMsg
|
||||||
// update the task progress
|
// update the task progress
|
||||||
s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
s.db.Updates(&job)
|
||||||
"progress": -1,
|
// 任务失败,通知前端
|
||||||
"err_msg": err.Error(),
|
|
||||||
})
|
|
||||||
s.notifyQueue.RPush(task.UserId)
|
s.notifyQueue.RPush(task.UserId)
|
||||||
// restore img_call quota
|
|
||||||
if task.Type.String() != types.TaskUpscale.String() {
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Infof("Task Executed: %+v", task)
|
logger.Infof("任务提交成功:%+v", res)
|
||||||
// lock the task until the execute timeout
|
// lock the task until the execute timeout
|
||||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||||
|
// 更新任务 ID/频道
|
||||||
|
job.TaskId = res.Result
|
||||||
|
job.ChannelId = s.Name
|
||||||
|
s.db.Updates(&job)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if current service instance can handle more task
|
// check if current service instance can handle more task
|
||||||
func (s *Service) canHandleTask() bool {
|
func (s *Service) canHandleTask() bool {
|
||||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
handledNum := atomic.LoadInt32(&s.HandledTaskNum)
|
||||||
return handledNum < s.maxHandleTaskNum
|
return handledNum < s.maxHandleTaskNum
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,65 +132,75 @@ func (s *Service) checkTasks() {
|
|||||||
for k, t := range s.taskStartTimes {
|
for k, t := range s.taskStartTimes {
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||||
delete(s.taskStartTimes, k)
|
delete(s.taskStartTimes, k)
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||||
// delete task from database
|
// delete task from database
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Notify(data CBReq) {
|
type CBReq struct {
|
||||||
// extract the task ID
|
Id string `json:"id"`
|
||||||
split := strings.Split(data.Prompt, " ")
|
Action string `json:"action"`
|
||||||
var job model.MidJourneyJob
|
Status string `json:"status"`
|
||||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
Prompt string `json:"prompt"`
|
||||||
if res.Error == nil && data.Status == Finished {
|
PromptEn string `json:"promptEn"`
|
||||||
logger.Warn("重复消息:", data.MessageId)
|
Description string `json:"description"`
|
||||||
return
|
SubmitTime int64 `json:"submitTime"`
|
||||||
}
|
StartTime int64 `json:"startTime"`
|
||||||
|
FinishTime int64 `json:"finishTime"`
|
||||||
tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC")
|
Progress string `json:"progress"`
|
||||||
if data.ReferenceId != "" {
|
ImageUrl string `json:"imageUrl"`
|
||||||
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
FailReason interface{} `json:"failReason"`
|
||||||
} else {
|
Properties struct {
|
||||||
tx = tx.Where("task_id = ?", split[0])
|
FinalPrompt string `json:"finalPrompt"`
|
||||||
}
|
} `json:"properties"`
|
||||||
// fixed: 修复 U/V 操作任务混淆覆盖的 Bug
|
}
|
||||||
if strings.Contains(data.Prompt, "** - Image #") { // for upscale
|
|
||||||
tx = tx.Where("type = ?", types.TaskUpscale.String())
|
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||||
} else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations
|
task, err := s.Client.QueryTask(job.TaskId)
|
||||||
tx = tx.Where("type = ?", types.TaskVariation.String())
|
if err != nil {
|
||||||
}
|
return err
|
||||||
res = tx.First(&job)
|
}
|
||||||
if res.Error != nil {
|
|
||||||
logger.Warn("非法任务:", res.Error)
|
// 任务执行失败了
|
||||||
return
|
if task.FailReason != "" {
|
||||||
}
|
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||||
|
"progress": -1,
|
||||||
job.ChannelId = data.ChannelId
|
"err_msg": task.FailReason,
|
||||||
job.MessageId = data.MessageId
|
})
|
||||||
job.ReferenceId = data.ReferenceId
|
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||||
job.Progress = data.Progress
|
}
|
||||||
job.Prompt = data.Prompt
|
|
||||||
job.Hash = data.Image.Hash
|
if len(task.Buttons) > 0 {
|
||||||
if s.client.Config.UseCDN {
|
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||||
job.UseProxy = true
|
}
|
||||||
job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
|
oldProgress := job.Progress
|
||||||
} else {
|
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||||
job.OrgURL = data.Image.URL
|
job.Prompt = task.PromptEn
|
||||||
}
|
if task.ImageUrl != "" {
|
||||||
|
job.OrgURL = task.ImageUrl
|
||||||
res = s.db.Updates(&job)
|
}
|
||||||
if res.Error != nil {
|
job.MessageId = task.Id
|
||||||
logger.Error("error with update job: ", res.Error)
|
tx := s.db.Updates(&job)
|
||||||
return
|
if tx.Error != nil {
|
||||||
}
|
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||||
|
}
|
||||||
if data.Status == Finished {
|
if task.Status == "SUCCESS" {
|
||||||
// release lock task
|
// release lock task
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||||
}
|
}
|
||||||
|
// 通知前端更新任务进度
|
||||||
s.notifyQueue.RPush(job.UserId)
|
if oldProgress != job.Progress {
|
||||||
|
s.notifyQueue.RPush(job.UserId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageHash(action string) string {
|
||||||
|
split := strings.Split(action, "::")
|
||||||
|
if len(split) > 5 {
|
||||||
|
return split[4]
|
||||||
|
}
|
||||||
|
return split[len(split)-1]
|
||||||
}
|
}
|
||||||
|
@ -1,35 +0,0 @@
|
|||||||
package mj
|
|
||||||
|
|
||||||
const (
|
|
||||||
ApplicationID string = "936929561302675456"
|
|
||||||
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InteractionsRequest struct {
|
|
||||||
Type int `json:"type"`
|
|
||||||
ApplicationID string `json:"application_id"`
|
|
||||||
MessageFlags int `json:"message_flags,omitempty"`
|
|
||||||
MessageID string `json:"message_id,omitempty"`
|
|
||||||
GuildID string `json:"guild_id"`
|
|
||||||
ChannelID string `json:"channel_id"`
|
|
||||||
SessionID string `json:"session_id"`
|
|
||||||
Data map[string]any `json:"data"`
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type InteractionsResult struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string
|
|
||||||
Error map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
type CBReq struct {
|
|
||||||
ChannelId string `json:"channel_id"`
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
ReferenceId string `json:"reference_id"`
|
|
||||||
Image Image `json:"image"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Status TaskStatus `json:"status"`
|
|
||||||
Progress int `json:"progress"`
|
|
||||||
}
|
|
@ -1,5 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
func main() {
|
import (
|
||||||
|
"chatplus/utils"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
text := "一只 蜗牛在树干上爬,阳光透过树叶照在蜗牛的背上 --ar 1:1 --iw 0.250000 --v 6"
|
||||||
|
fmt.Println(utils.HasChinese(text))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user