mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 10:43:44 +08:00
opt: enable use cdn url for mj-plus
This commit is contained in:
@@ -12,13 +12,12 @@ import (
|
||||
// MidJourney client
|
||||
|
||||
type Client struct {
|
||||
client *req.Client
|
||||
Config types.MidJourneyConfig
|
||||
imgCdnURL string
|
||||
apiURL string
|
||||
client *req.Client
|
||||
Config types.MidJourneyConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client {
|
||||
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
|
||||
client := req.C().SetTimeout(10 * time.Second)
|
||||
var apiURL string
|
||||
// set proxy URL
|
||||
@@ -31,7 +30,7 @@ func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *C
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
|
||||
return &Client{client: client, Config: config, apiURL: apiURL}
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(task types.MjTask) error {
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
@@ -18,10 +19,17 @@ var logger = logger2.GetLogger()
|
||||
// Client MidJourney Plus Client
|
||||
type Client struct {
|
||||
Config types.MidJourneyPlusConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewClient(config types.MidJourneyPlusConfig) *Client {
|
||||
return &Client{Config: config}
|
||||
var apiURL string
|
||||
if config.CdnURL != "" {
|
||||
apiURL = config.CdnURL
|
||||
} else {
|
||||
apiURL = config.ApiURL
|
||||
}
|
||||
return &Client{Config: config, apiURL: apiURL}
|
||||
}
|
||||
|
||||
type ImageReq struct {
|
||||
@@ -54,12 +62,12 @@ type ErrRes struct {
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.apiURL)
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Prompt: task.Prompt,
|
||||
NotifyHook: c.Config.NotifyURL,
|
||||
Base64Array: make([]string, 1),
|
||||
Base64Array: make([]string, 0),
|
||||
}
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) > 0 {
|
||||
@@ -67,7 +75,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
} else {
|
||||
body.Base64Array[0] = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
|
||||
}
|
||||
|
||||
}
|
||||
@@ -80,12 +88,12 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
SetErrorResult(&errRes).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v,%v", err, string(errStr))
|
||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
||||
errStr, _ := io.ReadAll(r.Body)
|
||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
@@ -93,7 +101,7 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
||||
|
||||
// Blend 融图
|
||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/blend", c.Config.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/blend", c.apiURL)
|
||||
body := ImageReq{
|
||||
BotType: "MID_JOURNEY",
|
||||
Dimensions: "SQUARE",
|
||||
@@ -133,7 +141,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
||||
|
||||
// SwapFace 换脸
|
||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/insight-face/swap", c.Config.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj-fast/mj/insight-face/swap", c.apiURL)
|
||||
// 生成图片 Base64 编码
|
||||
if len(task.ImgArr) != 2 {
|
||||
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
||||
@@ -189,7 +197,7 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
|
||||
"taskId": task.MessageId,
|
||||
"notifyHook": c.Config.NotifyURL,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
@@ -216,7 +224,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
|
||||
"taskId": task.MessageId,
|
||||
"notifyHook": c.Config.NotifyURL,
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
|
||||
var res ImageRes
|
||||
var errRes ErrRes
|
||||
r, err := req.C().R().
|
||||
@@ -262,7 +270,7 @@ type QueryRes struct {
|
||||
}
|
||||
|
||||
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
|
||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId)
|
||||
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
|
||||
var res QueryRes
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
|
||||
@@ -86,10 +86,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||||
errMsg := err.Error() + res.Description
|
||||
errMsg := fmt.Sprintf("%v,%s", err, res.Description)
|
||||
logger.Error("绘画任务执行失败:", errMsg)
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
|
||||
s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": errMsg,
|
||||
})
|
||||
@@ -105,10 +105,10 @@ func (s *Service) Run() {
|
||||
}
|
||||
logger.Infof("任务提交成功:%+v", res)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||
atomic.AddInt32(&s.HandledTaskNum, 1)
|
||||
// 更新任务 ID/频道
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
|
||||
s.db.Debug().Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": res.Result,
|
||||
"channel_id": s.Name,
|
||||
})
|
||||
@@ -152,26 +152,55 @@ type CBReq struct {
|
||||
} `json:"properties"`
|
||||
}
|
||||
|
||||
func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
|
||||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||||
task, err := s.Client.QueryTask(job.TaskId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
|
||||
job.Prompt = data.Properties.FinalPrompt
|
||||
if data.ImageUrl != "" {
|
||||
job.OrgURL = data.ImageUrl
|
||||
// 任务执行失败了
|
||||
if task.FailReason != "" {
|
||||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
return fmt.Errorf("task failed: %v", task.FailReason)
|
||||
}
|
||||
|
||||
if len(task.Buttons) > 0 {
|
||||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := job.Progress
|
||||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
job.Prompt = task.PromptEn
|
||||
if task.ImageUrl != "" {
|
||||
if s.Client.Config.CdnURL != "" {
|
||||
job.OrgURL = strings.Replace(task.ImageUrl, s.Client.Config.ApiURL, s.Client.Config.CdnURL, 1)
|
||||
} else {
|
||||
job.OrgURL = task.ImageUrl
|
||||
}
|
||||
}
|
||||
job.UseProxy = true
|
||||
job.MessageId = data.Id
|
||||
logger.Debugf("JOB: %+v", job)
|
||||
res := s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("error with update job: %v", res.Error)
|
||||
job.MessageId = task.Id
|
||||
tx := s.db.Updates(&job)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("error with update database: %v", tx.Error)
|
||||
}
|
||||
|
||||
if data.Status == "SUCCESS" {
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.HandledTaskNum, -1)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != job.Progress {
|
||||
s.notifyQueue.RPush(job.UserId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetImageHash(action string) string {
|
||||
split := strings.Split(action, "::")
|
||||
if len(split) > 5 {
|
||||
return split[4]
|
||||
}
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
@@ -6,11 +6,9 @@ import (
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -35,9 +33,8 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
if config.ApiURL != "https://gpt.bemore.lol" && config.ApiURL != "https://api.chat-plus.net" {
|
||||
config.ApiURL = "https://api.chat-plus.net"
|
||||
}
|
||||
// rewrite api key
|
||||
config.ApiURL = "https://api.chat-plus.net"
|
||||
client := plus.NewClient(config)
|
||||
name := fmt.Sprintf("mj-service-plus-%d", k)
|
||||
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
|
||||
@@ -54,7 +51,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
|
||||
client := NewClient(config, appConfig.ProxyURL)
|
||||
|
||||
name := fmt.Sprintf("MjService-%d", k)
|
||||
// create mj service
|
||||
@@ -98,6 +95,9 @@ func (p *ServicePool) CheckTaskNotify() {
|
||||
continue
|
||||
}
|
||||
client := p.Clients.Get(userId)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -120,17 +120,17 @@ func (p *ServicePool) DownloadImages() {
|
||||
if v.OrgURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try to download image: %s", v.OrgURL)
|
||||
var imgURL string
|
||||
var err error
|
||||
if v.UseProxy {
|
||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
||||
task, _ := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if task.ImageUrl != "" {
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
|
||||
}
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = getImageHash(task.Buttons[0].CustomId)
|
||||
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
|
||||
}
|
||||
} else {
|
||||
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
|
||||
@@ -138,12 +138,17 @@ func (p *ServicePool) DownloadImages() {
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
continue
|
||||
} else {
|
||||
logger.Info("download image %v successfully.", v.OrgURL)
|
||||
}
|
||||
|
||||
v.ImgURL = imgURL
|
||||
p.db.Updates(&v)
|
||||
|
||||
client := p.Clients.Get(uint(v.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -179,7 +184,7 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
|
||||
return nil
|
||||
}
|
||||
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
|
||||
return servicePlus.Notify(data, job)
|
||||
return servicePlus.Notify(job)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -211,40 +216,7 @@ func (p *ServicePool) SyncTaskProgress() {
|
||||
}
|
||||
|
||||
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
|
||||
task, err := servicePlus.Client.QueryTask(v.TaskId)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// 任务失败了
|
||||
if task.FailReason != "" {
|
||||
p.db.Model(&model.MidJourneyJob{Id: v.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": -1,
|
||||
"err_msg": task.FailReason,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if len(task.Buttons) > 0 {
|
||||
v.Hash = getImageHash(task.Buttons[0].CustomId)
|
||||
}
|
||||
oldProgress := v.Progress
|
||||
v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||||
v.Prompt = task.PromptEn
|
||||
if task.ImageUrl != "" {
|
||||
v.OrgURL = task.ImageUrl
|
||||
}
|
||||
v.UseProxy = true
|
||||
v.MessageId = task.Id
|
||||
|
||||
p.db.Updates(&v)
|
||||
|
||||
if task.Status == "SUCCESS" {
|
||||
// release lock task
|
||||
atomic.AddInt32(&servicePlus.HandledTaskNum, -1)
|
||||
}
|
||||
// 通知前端更新任务进度
|
||||
if oldProgress != v.Progress {
|
||||
p.notifyQueue.RPush(v.UserId)
|
||||
}
|
||||
_ = servicePlus.Notify(v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,11 +235,3 @@ func (p *ServicePool) getServicePlus(name string) *plus.Service {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getImageHash(action string) string {
|
||||
split := strings.Split(action, "::")
|
||||
if len(split) > 5 {
|
||||
return split[4]
|
||||
}
|
||||
return split[len(split)-1]
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
s.taskStartTimes[int(task.Id)] = time.Now()
|
||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func (s *Service) Notify(data CBReq) {
|
||||
job.OrgURL = data.Image.URL
|
||||
if s.client.Config.UseCDN {
|
||||
job.UseProxy = true
|
||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
|
||||
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL)
|
||||
}
|
||||
|
||||
res = s.db.Updates(&job)
|
||||
|
||||
@@ -56,7 +56,7 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
|
||||
}
|
||||
p.Add("mchid", js.config.AppId)
|
||||
|
||||
p.Add("Sign", js.sign(p))
|
||||
p.Add("sign", js.sign(p))
|
||||
|
||||
cli := http.Client{}
|
||||
apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
|
||||
|
||||
Reference in New Issue
Block a user