refactor mj service, add mj service pool support

This commit is contained in:
RockYang 2023-12-12 18:33:24 +08:00
parent c012f0c4c5
commit cf758d773e
13 changed files with 201 additions and 320 deletions

View File

@ -33,7 +33,7 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload",
},
},
MjConfig: types.MidJourneyConfig{Enabled: false},
MjConfigs: types.MidJourneyConfig{Enabled: false},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},

View File

@ -18,7 +18,7 @@ type AppConfig struct {
AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfig MidJourneyConfig // mj 绘画配置
MjConfigs []MidJourneyConfig // mj 绘画配置池子
WeChatBot bool // 是否启用微信机器人
SdConfig StableDiffusionConfig // sd 绘画配置
@ -116,7 +116,7 @@ type ChatConfig struct {
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
ContextDeep int `json:"context_deep"` // 上下文深度
DallApiURL string `json:"dall_api_url"` // dall-e3 绘图 API 地址
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
}
type Platform string

View File

@ -11,28 +11,15 @@ const (
TaskImage = TaskType("image")
TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation")
TaskTxt2Img = TaskType("text2img")
)
// TaskSrc 任务来源
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
TaskSrcImg = TaskSrc("img") // 专业绘画页面
)
// MjTask MidJourney 任务
type MjTask struct {
Id int `json:"id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
@ -42,7 +29,6 @@ type MjTask struct {
type SdTask struct {
Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`

View File

@ -12,9 +12,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
@ -40,20 +38,6 @@ func NewMidJourneyHandler(
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
h.mjService.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
@ -72,7 +56,7 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
if !h.App.Config.MjConfig.Enabled {
if !h.App.Config.MjConfigs[0].Enabled {
resp.ERROR(c, "MidJourney service is disabled")
return
}

View File

@ -165,23 +165,6 @@ func main() {
// MidJourney 机器人
fx.Provide(mj.NewBot),
fx.Provide(mj.NewClient),
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
if config.MjConfig.Enabled {
err := bot.Run()
if err != nil {
log.Fatal("MidJourney 服务启动失败:", err)
}
}
}),
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
if config.MjConfig.Enabled {
go func() {
mjService.Run()
}()
}
}),
// Stable Diffusion 机器人
fx.Provide(sd.NewService),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
@ -256,13 +239,11 @@ func main() {
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}),
// 管理后台控制器

View File

@ -23,7 +23,7 @@ type Bot struct {
}
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
if err != nil {
return nil, err
}
@ -41,7 +41,7 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
}
return &Bot{
config: &config.MjConfig,
config: &config.MjConfigs,
bot: discord,
service: service,
}, nil

View File

@ -2,6 +2,7 @@ package mj
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/imroc/req/v3"
"time"
@ -14,13 +15,13 @@ type Client struct {
config *types.MidJourneyConfig
}
func NewClient(config *types.AppConfig) *Client {
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second)
// set proxy URL
if config.ProxyURL != "" {
client.SetProxyURL(config.ProxyURL)
if utils.IsEmptyValue(proxy) {
client.SetProxyURL(proxy)
}
return &Client{client: client, config: &config.MjConfig}
return &Client{client: client, config: config}
}
func (c *Client) Imagine(prompt string) error {

38
api/service/mj/pool.go Normal file
View File

@ -0,0 +1,38 @@
package mj
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
// create mj client and service
for _, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(&config, appConfig.ProxyURL)
// create mj service
service := NewService()
}
return &ServicePool{
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
}
}
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}

View File

@ -2,63 +2,63 @@ package mj
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
"strings"
"sync/atomic"
"time"
)
// MJ 绘画服务
const RunningJobKey = "MidJourney_Running_Job"
// Service MJ 绘画服务
type Service struct {
client *Client // MJ 客户端
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
proxyURL string
name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
proxyURL string
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
snowflake *service.Snowflake
}
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{
redis: redisCli,
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
client: client,
uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](),
ChatClients: types.NewLMap[string, *types.WsClient](),
proxyURL: config.ProxyURL,
name: name,
db: db,
taskQueue: queue,
client: client,
uploadManager: manager,
taskTimeout: timeout,
proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
}
}
func (s *Service) Run() {
logger.Info("Starting MidJourney job consumer.")
ctx := context.Background()
logger.Infof("Starting MidJourney job consumer for %s", s.name)
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err = s.taskQueue.LPop(&task)
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
logger.Infof("handle a new MidJourney task: %+v", task)
switch task.Type {
case types.TaskImage:
err = s.client.Imagine(task.Prompt)
@ -70,50 +70,40 @@ func (s *Service) Run() {
case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
}
if err != nil {
logger.Error("绘画任务执行失败:", err)
// 删除任务
s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)})
// 推送任务到前端
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, vo.MidJourneyJob{
Type: task.Type.String(),
UserId: task.UserId,
MessageId: task.MessageId,
Progress: -1,
Prompt: task.Prompt,
})
}
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
atomic.AddInt32(&s.handledTaskNum, -1)
continue
}
// 更新任务的执行状态
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
func (s *Service) PushTask(task types.MjTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
}
}
}
func (s *Service) Notify(data CBReq) {
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return
}
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return
}
// extract the task ID
split := strings.Split(data.Prompt, " ")
var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished {
@ -121,137 +111,37 @@ func (s *Service) Notify(data CBReq) {
return
}
if task.Src == types.TaskSrcImg { // 绘画任务
var job model.MidJourneyJob
res := s.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
res = s.db.Where("task_id = ?", split[0]).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL // save origin image
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
} else {
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
// upload image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
if data.Progress < 100 {
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
// 推送任务到前端
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := s.ChatClients.Get(task.SessionId)
if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
utils.ReplyMessage(wsClient, content)
}
// download image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download image: ", err)
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
return
}
tx := s.db.Begin()
data.Image.URL = imgURL
message := model.HistoryMessage{
UserId: uint(task.UserId),
ChatId: task.ChatId,
RoleId: uint(task.RoleId),
Type: types.MjMsg,
Icon: task.Icon,
Content: utils.JsonEncode(data),
Tokens: 0,
UseContext: false,
}
res = tx.Create(&message)
if res.Error != nil {
logger.Error("error with update database: ", err)
return
}
// save the job
job.UserId = task.UserId
job.Type = task.Type.String()
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Prompt = data.Prompt
job.ImgURL = imgURL
job.Progress = data.Progress
job.Hash = data.Image.Hash
job.CreatedAt = time.Now()
res = tx.Create(&job)
if res.Error != nil {
logger.Error("error with update database: ", err)
tx.Rollback()
return
}
tx.Commit()
}
if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data)
return
}
if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// 本次绘画完毕,移除客户端
s.ChatClients.Delete(task.SessionId)
} else {
// 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
// 更新用户剩余绘图次数
// TODO: 放大图片是否需要消耗绘图次数?
if data.Status == Finished {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 解除任务锁定
s.redis.Del(context.Background(), RunningJobKey)
// update user's img calls
s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// release lock task
atomic.AddInt32(&s.handledTaskNum, -1)
}
}

View File

@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
var data string
keys := make([]string, 0, 0)
params["appid"] = s.appId
for key, _ := range params {
for key := range params {
keys = append(keys, key)
}
sort.Strings(keys)

View File

@ -6,13 +6,14 @@ type MidJourneyJob struct {
Id uint `gorm:"primarykey;column:id"`
Type string
UserId int
TaskId string
MessageId string
ReferenceId string
ImgURL string
OrgURL string // 原图地址
Hash string // message hash
Progress int
Prompt string
Started bool
CreatedAt time.Time
}

View File

@ -9,9 +9,9 @@ type MidJourneyJob struct {
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
ImgURL string `json:"img_url"`
OrgURL string `json:"org_url"`
Hash string `json:"hash"`
Progress int `json:"progress"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
}

View File

@ -504,72 +504,72 @@ const socket = ref(null)
const imgCalls = ref(0)
const loading = ref(false)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
let isNew = true
if (data.progress === 100) {
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
isNew = false
break
}
}
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1)
break
}
}
if (isNew) {
finishedJobs.value.unshift(data)
}
} else if (data.progress === -1) { //
ElNotification({
title: '任务执行失败',
message: "提示词:" + data['prompt'],
type: 'error',
})
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
} else {
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
isNew = false
runningJobs.value[i] = data
break
}
}
if (isNew) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
ElMessage.error("Websocket 已经断开,正在重新连接服务器")
connect()
});
}
// const connect = () => {
// let host = process.env.VUE_APP_WS_HOST
// if (host === '') {
// if (location.protocol === 'https:') {
// host = 'wss://' + location.host;
// } else {
// host = 'ws://' + location.host;
// }
// }
// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
// _socket.addEventListener('open', () => {
// socket.value = _socket;
// });
//
// _socket.addEventListener('message', event => {
// if (event.data instanceof Blob) {
// const reader = new FileReader();
// reader.readAsText(event.data, "UTF-8");
// reader.onload = () => {
// const data = JSON.parse(String(reader.result));
// let isNew = true
// if (data.progress === 100) {
// for (let i = 0; i < finishedJobs.value.length; i++) {
// if (finishedJobs.value[i].id === data.id) {
// isNew = false
// break
// }
// }
// for (let i = 0; i < runningJobs.value.length; i++) {
// if (runningJobs.value[i].id === data.id) {
// runningJobs.value.splice(i, 1)
// break
// }
// }
// if (isNew) {
// finishedJobs.value.unshift(data)
// }
// } else if (data.progress === -1) { //
// ElNotification({
// title: '',
// message: "" + data['prompt'],
// type: 'error',
// })
// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
//
// } else {
// for (let i = 0; i < runningJobs.value.length; i++) {
// if (runningJobs.value[i].id === data.id) {
// isNew = false
// runningJobs.value[i] = data
// break
// }
// }
// if (isNew) {
// runningJobs.value.push(data)
// }
// }
// }
// }
// });
//
// _socket.addEventListener('close', () => {
// ElMessage.error("Websocket ")
// connect()
// });
// }
const translatePrompt = () => {
loading.value = true