mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 18:53:43 +08:00
feat: midjourney plus service is ready
This commit is contained in:
@@ -33,7 +33,7 @@ func NewBot(name string, proxy string, config types.MidJourneyConfig, service *S
|
||||
// use CDN reverse proxy
|
||||
if config.UseCDN {
|
||||
discordgo.SetEndpointDiscord(config.DiscordAPI)
|
||||
discordgo.SetEndpointCDN(config.DiscordCDN)
|
||||
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
|
||||
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
|
||||
bot.MjGateway = config.DiscordGateway + "/"
|
||||
} else { // use proxy
|
||||
|
||||
@@ -11,12 +11,13 @@ import (
|
||||
// MidJourney client
|
||||
|
||||
type Client struct {
|
||||
client *req.Client
|
||||
Config types.MidJourneyConfig
|
||||
apiURL string
|
||||
client *req.Client
|
||||
Config types.MidJourneyConfig
|
||||
imgCdnURL string
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
||||
func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client {
|
||||
client := req.C().SetTimeout(10 * time.Second)
|
||||
var apiURL string
|
||||
// set proxy URL
|
||||
@@ -29,7 +30,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{client: client, Config: config, apiURL: apiURL}
|
||||
return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(prompt string) error {
|
||||
|
||||
171
api/service/mj/plus/client.go
Normal file
171
api/service/mj/plus/client.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package plus
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// Client MidJourney Plus Client
|
||||
type Client struct {
|
||||
Config types.MidJourneyPlusConfig
|
||||
}
|
||||
|
||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
|
||||
return &Client{Config: config}
|
||||
}
|
||||
|
||||
type ImageReq struct {
|
||||
BotType string `json:"botType"`
|
||||
Prompt string `json:"prompt"`
|
||||
Base64Array []interface{} `json:"base64Array,omitempty"`
|
||||
AccountFilter struct {
|
||||
InstanceId string `json:"instanceId"`
|
||||
Modes []interface{} `json:"modes"`
|
||||
Remix bool `json:"remix"`
|
||||
RemixAutoConsidered bool `json:"remixAutoConsidered"`
|
||||
} `json:"accountFilter,omitempty"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
State string `json:"state,omitempty"`
|
||||
}
|
||||
|
||||
type ImageRes struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties struct {
|
||||
} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type ErrRes struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(prompt string) (ImageRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Prompt: prompt,
|
||||
NotifyHook: c.Config.NotifyURL,
|
||||
}
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Upscale 放大指定的图片
|
||||
func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) {
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
|
||||
"taskId": messageId,
|
||||
"notifyHook": c.Config.NotifyURL,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||
func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) {
|
||||
body := map[string]string{
|
||||
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
|
||||
"taskId": messageId,
|
||||
"notifyHook": c.Config.NotifyURL,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
type QueryRes struct {
|
||||
Action string `json:"action"`
|
||||
Buttons []struct {
|
||||
CustomId string `json:"customId"`
|
||||
Emoji string `json:"emoji"`
|
||||
Label string `json:"label"`
|
||||
Style int `json:"style"`
|
||||
Type int `json:"type"`
|
||||
} `json:"buttons"`
|
||||
Description string `json:"description"`
|
||||
FailReason string `json:"failReason"`
|
||||
FinishTime int `json:"finishTime"`
|
||||
Id string `json:"id"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Progress string `json:"progress"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Properties struct {
|
||||
} `json:"properties"`
|
||||
StartTime int `json:"startTime"`
|
||||
State string `json:"state"`
|
||||
Status string `json:"status"`
|
||||
SubmitTime int `json:"submitTime"`
|
||||
}
|
||||
|
||||
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId)
|
||||
var res QueryRes
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return QueryRes{}, err
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return QueryRes{}, errors.New("error status:" + r.Status)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
164
api/service/mj/plus/service.go
Normal file
164
api/service/mj/plus/service.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package plus
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Service MJ 绘画服务
|
||||
type Service struct {
|
||||
name string // service name
|
||||
Client *Client // MJ Client
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
handledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
}
|
||||
|
||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
||||
return &Service{
|
||||
name: name,
|
||||
db: db,
|
||||
taskQueue: taskQueue,
|
||||
notifyQueue: notifyQueue,
|
||||
Client: client,
|
||||
taskTimeout: timeout,
|
||||
maxHandleTaskNum: maxTaskNum,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
||||
for {
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
var task types.MjTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// if it's reference message, check if it's this channel's message
|
||||
if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name {
|
||||
s.taskQueue.RPush(task)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
||||
var res ImageRes
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
index := strings.Index(task.Prompt, " ")
|
||||
res, err = s.Client.Imagine(task.Prompt[index+1:])
|
||||
break
|
||||
case types.TaskUpscale:
|
||||
res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash)
|
||||
break
|
||||
case types.TaskVariation:
|
||||
res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash)
|
||||
}
|
||||
|
||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
// 任务失败,通知前端
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
// restore img_call quota
|
||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
||||
|
||||
// TODO: 任务提交失败,加入队列重试
|
||||
continue
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||
// 更新任务 ID/频道
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
|
||||
"task_id": res.Result,
|
||||
"channel_id": s.Client.Config.Name,
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
// remove the expired tasks
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
// delete task from database
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
Id string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
Progress string `json:"progress"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
FailReason interface{} `json:"failReason"`
|
||||
Properties struct {
|
||||
FinalPrompt string `json:"finalPrompt"`
|
||||
} `json:"properties"`
|
||||
}
|
||||
|
||||
func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
|
||||
|
||||
job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
|
||||
job.Prompt = data.Properties.FinalPrompt
|
||||
if data.ImageUrl != "" {
|
||||
job.OrgURL = data.ImageUrl
|
||||
}
|
||||
job.UseProxy = true
|
||||
job.MessageId = data.Id
|
||||
logger.Debugf("JOB: %+v", job)
|
||||
res := s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with update job: %v", res.Error)
|
||||
}
|
||||
|
||||
if data.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
return nil
|
||||
}
|
||||
@@ -2,11 +2,13 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj/plus"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -14,7 +16,7 @@ import (
|
||||
|
||||
// ServicePool Mj service pool
|
||||
type ServicePool struct {
|
||||
services []*Service
|
||||
services []interface{}
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
@@ -23,37 +25,53 @@ type ServicePool struct {
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
services := make([]*Service, 0)
|
||||
services := make([]interface{}, 0)
|
||||
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
|
||||
// create mj client and service
|
||||
for k, config := range appConfig.MjConfigs {
|
||||
|
||||
for k, config := range appConfig.MjPlusConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(config, appConfig.ProxyURL)
|
||||
|
||||
name := fmt.Sprintf("MjService-%d", k)
|
||||
// create mj service
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
||||
botName := fmt.Sprintf("MjBot-%d", k)
|
||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err = bot.Run()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// run mj service
|
||||
client := plus.NewClient(config)
|
||||
name := fmt.Sprintf("MidJourney Plus Service-%d", k)
|
||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
||||
go func() {
|
||||
service.Run()
|
||||
servicePlus.Run()
|
||||
}()
|
||||
services = append(services, servicePlus)
|
||||
}
|
||||
|
||||
services = append(services, service)
|
||||
if len(services) == 0 {
|
||||
// create mj client and service
|
||||
for k, config := range appConfig.MjConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
|
||||
|
||||
name := fmt.Sprintf("MjService-%d", k)
|
||||
// create mj service
|
||||
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
|
||||
botName := fmt.Sprintf("MjBot-%d", k)
|
||||
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err = bot.Run()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// run mj service
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
|
||||
services = append(services, service)
|
||||
}
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
@@ -94,7 +112,24 @@ func (p *ServicePool) DownloadImages() {
|
||||
|
||||
// download images
|
||||
for _, v := range items {
|
||||
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
|
||||
if v.OrgURL == "" {
|
||||
continue
|
||||
}
|
||||
var imgURL string
|
||||
var err error
|
||||
if v.UseProxy {
|
||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if task.ImageUrl != "" {
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
|
||||
}
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = getImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
continue
|
||||
@@ -125,3 +160,37 @@ func (p *ServicePool) PushTask(task types.MjTask) {
|
||||
func (p *ServicePool) HasAvailableService() bool {
|
||||
return len(p.services) > 0
|
||||
}
|
||||
|
||||
func (p *ServicePool) Notify(data plus.CBReq) error {
|
||||
logger.Infof("收到任务回调:%+v", data)
|
||||
var job model.MidJourneyJob
|
||||
res := p.db.Where("task_id = ?", data.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("非法任务:%s", data.Id)
|
||||
}
|
||||
|
||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
||||
return servicePlus.Notify(data, job)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
||||
for _, s := range p.services {
|
||||
if servicePlus, ok := s.(*plus.Service); ok {
|
||||
if servicePlus.Client.Config.Name == name {
|
||||
return servicePlus
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getImageHash(action string) string {
|
||||
split := strings.Split(action, "::")
|
||||
if len(split) > 5 {
|
||||
return split[4]
|
||||
}
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
@@ -58,8 +58,6 @@ func (s *Service) Run() {
|
||||
// if it's reference message, check if it's this channel's message
|
||||
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
||||
s.taskQueue.RPush(task)
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
@@ -143,7 +141,7 @@ func (s *Service) Notify(data CBReq) {
|
||||
job.OrgURL = data.Image.URL
|
||||
if s.client.Config.UseCDN {
|
||||
job.UseProxy = true
|
||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
|
||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
|
||||
}
|
||||
|
||||
res = s.db.Updates(&job)
|
||||
|
||||
Reference in New Issue
Block a user