opt: enable use cdn url for mj-plus

This commit is contained in:
RockYang
2024-01-28 21:56:25 +08:00
parent f08a7862de
commit bf65746d00
17 changed files with 193 additions and 158 deletions

View File

@@ -6,11 +6,9 @@ import (
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"github.com/go-redis/redis/v8"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
@@ -35,9 +33,8 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
if config.Enabled == false {
continue
}
if config.ApiURL != "https://gpt.bemore.lol" && config.ApiURL != "https://api.chat-plus.net" {
config.ApiURL = "https://api.chat-plus.net"
}
// rewrite api key
config.ApiURL = "https://api.chat-plus.net"
client := plus.NewClient(config)
name := fmt.Sprintf("mj-service-plus-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
@@ -54,7 +51,7 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
@@ -98,6 +95,9 @@ func (p *ServicePool) CheckTaskNotify() {
continue
}
client := p.Clients.Get(userId)
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
@@ -120,17 +120,17 @@ func (p *ServicePool) DownloadImages() {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if v.UseProxy {
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if task.ImageUrl != "" {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
v.Hash = plus.GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
}
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
@@ -138,12 +138,17 @@ func (p *ServicePool) DownloadImages() {
if err != nil {
logger.Error("error with download image: ", err)
continue
} else {
logger.Info("download image %v successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
client := p.Clients.Get(uint(v.UserId))
if client == nil {
continue
}
err = client.Send([]byte("Task Updated"))
if err != nil {
continue
@@ -179,7 +184,7 @@ func (p *ServicePool) Notify(data plus.CBReq) error {
return nil
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(data, job)
return servicePlus.Notify(job)
}
return nil
@@ -211,40 +216,7 @@ func (p *ServicePool) SyncTaskProgress() {
}
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, err := servicePlus.Client.QueryTask(v.TaskId)
if err != nil {
continue
}
// 任务失败了
if task.FailReason != "" {
p.db.Model(&model.MidJourneyJob{Id: v.Id}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": task.FailReason,
})
continue
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
}
oldProgress := v.Progress
v.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
v.Prompt = task.PromptEn
if task.ImageUrl != "" {
v.OrgURL = task.ImageUrl
}
v.UseProxy = true
v.MessageId = task.Id
p.db.Updates(&v)
if task.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&servicePlus.HandledTaskNum, -1)
}
// 通知前端更新任务进度
if oldProgress != v.Progress {
p.notifyQueue.RPush(v.UserId)
}
_ = servicePlus.Notify(v)
}
}
@@ -263,11 +235,3 @@ func (p *ServicePool) getServicePlus(name string) *plus.Service {
}
return nil
}
func getImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}