refactor mj service, add mj service pool support

This commit is contained in:
RockYang 2023-12-12 18:33:24 +08:00
parent c012f0c4c5
commit cf758d773e
13 changed files with 201 additions and 320 deletions

View File

@ -33,7 +33,7 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
MjConfig: types.MidJourneyConfig{Enabled: false}, MjConfigs: types.MidJourneyConfig{Enabled: false},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"}, SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false, WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},

View File

@ -18,7 +18,7 @@ type AppConfig struct {
AesEncryptKey string AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjConfig MidJourneyConfig // mj 绘画配置 MjConfigs []MidJourneyConfig // mj 绘画配置池子
WeChatBot bool // 是否启用微信机器人 WeChatBot bool // 是否启用微信机器人
SdConfig StableDiffusionConfig // sd 绘画配置 SdConfig StableDiffusionConfig // sd 绘画配置

View File

@ -11,28 +11,15 @@ const (
TaskImage = TaskType("image") TaskImage = TaskType("image")
TaskUpscale = TaskType("upscale") TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation") TaskVariation = TaskType("variation")
TaskTxt2Img = TaskType("text2img")
)
// TaskSrc 任务来源
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
TaskSrcImg = TaskSrc("img") // 专业绘画页面
) )
// MjTask MidJourney 任务 // MjTask MidJourney 任务
type MjTask struct { type MjTask struct {
Id int `json:"id"` Id int `json:"id"`
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"` MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"` MessageHash string `json:"message_hash,omitempty"`
@ -42,7 +29,6 @@ type MjTask struct {
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`

View File

@ -12,9 +12,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"strings" "strings"
"time" "time"
) )
@ -40,20 +38,6 @@ func NewMidJourneyHandler(
return &h return &h
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
h.mjService.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := utils.GetLoginUser(c, h.db)
if err != nil { if err != nil {
@ -72,7 +56,7 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
if !h.App.Config.MjConfig.Enabled { if !h.App.Config.MjConfigs[0].Enabled {
resp.ERROR(c, "MidJourney service is disabled") resp.ERROR(c, "MidJourney service is disabled")
return return
} }

View File

@ -165,23 +165,6 @@ func main() {
// MidJourney 机器人 // MidJourney 机器人
fx.Provide(mj.NewBot), fx.Provide(mj.NewBot),
fx.Provide(mj.NewClient),
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
if config.MjConfig.Enabled {
err := bot.Run()
if err != nil {
log.Fatal("MidJourney 服务启动失败:", err)
}
}
}),
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
if config.MjConfig.Enabled {
go func() {
mjService.Run()
}()
}
}),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewService), fx.Provide(sd.NewService),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) { fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
@ -256,13 +239,11 @@ func main() {
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}), }),
// 管理后台控制器 // 管理后台控制器

View File

@ -23,7 +23,7 @@ type Bot struct {
} }
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken) discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -41,7 +41,7 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
} }
return &Bot{ return &Bot{
config: &config.MjConfig, config: &config.MjConfigs,
bot: discord, bot: discord,
service: service, service: service,
}, nil }, nil

View File

@ -2,6 +2,7 @@ package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils"
"fmt" "fmt"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"time" "time"
@ -14,13 +15,13 @@ type Client struct {
config *types.MidJourneyConfig config *types.MidJourneyConfig
} }
func NewClient(config *types.AppConfig) *Client { func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second) client := req.C().SetTimeout(10 * time.Second)
// set proxy URL // set proxy URL
if config.ProxyURL != "" { if utils.IsEmptyValue(proxy) {
client.SetProxyURL(config.ProxyURL) client.SetProxyURL(proxy)
} }
return &Client{client: client, config: &config.MjConfig} return &Client{client: client, config: config}
} }
func (c *Client) Imagine(prompt string) error { func (c *Client) Imagine(prompt string) error {

38
api/service/mj/pool.go Normal file
View File

@ -0,0 +1,38 @@
package mj
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
// create mj client and service
for _, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(&config, appConfig.ProxyURL)
// create mj service
service := NewService()
}
return &ServicePool{
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
}
}
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}

View File

@ -2,63 +2,63 @@ package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
"sync/atomic"
"time" "time"
) )
// MJ 绘画服务 // Service MJ 绘画服务
const RunningJobKey = "MidJourney_Running_Job"
type Service struct { type Service struct {
client *Client // MJ 客户端 name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
proxyURL string proxyURL string
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
snowflake *service.Snowflake
} }
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{ return &Service{
redis: redisCli, name: name,
db: db, db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), taskQueue: queue,
client: client, client: client,
uploadManager: manager, uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](), taskTimeout: timeout,
ChatClients: types.NewLMap[string, *types.WsClient](),
proxyURL: config.ProxyURL, proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Info("Starting MidJourney job consumer.") logger.Infof("Starting MidJourney job consumer for %s", s.name)
ctx := context.Background()
for { for {
_, err := s.redis.Get(ctx, RunningJobKey).Result() s.checkTasks()
if err == nil { // 队列串行执行 if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
var task types.MjTask var task types.MjTask
err = s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
logger.Infof("Consuming Task: %+v", task)
logger.Infof("handle a new MidJourney task: %+v", task)
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
err = s.client.Imagine(task.Prompt) err = s.client.Imagine(task.Prompt)
@ -70,50 +70,40 @@ func (s *Service) Run() {
case types.TaskVariation: case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash) err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
} }
if err != nil { if err != nil {
logger.Error("绘画任务执行失败:", err) logger.Error("绘画任务执行失败:", err)
// 删除任务 // update the task progress
s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)}) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// 推送任务到前端 atomic.AddInt32(&s.handledTaskNum, -1)
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, vo.MidJourneyJob{
Type: task.Type.String(),
UserId: task.UserId,
MessageId: task.MessageId,
Progress: -1,
Prompt: task.Prompt,
})
}
continue continue
} }
// 更新任务的执行状态 // lock the task until the execute timeout
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) s.taskStartTimes[task.Id] = time.Now()
// 锁定任务执行通道直到任务超时5分钟 atomic.AddInt32(&s.handledTaskNum, 1)
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
} }
} }
func (s *Service) PushTask(task types.MjTask) { // check if current service instance can handle more task
logger.Infof("add a new MidJourney Task: %+v", task) func (s *Service) canHandleTask() bool {
s.taskQueue.RPush(task) handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
}
}
} }
func (s *Service) Notify(data CBReq) { func (s *Service) Notify(data CBReq) {
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result() // extract the task ID
if err != nil { // 过期任务,丢弃 split := strings.Split(data.Prompt, " ")
logger.Warn("任务已过期:", err)
return
}
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return
}
var job model.MidJourneyJob var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job) res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished { if res.Error == nil && data.Status == Finished {
@ -121,9 +111,7 @@ func (s *Service) Notify(data CBReq) {
return return
} }
if task.Src == types.TaskSrcImg { // 绘画任务 res = s.db.Where("task_id = ?", split[0]).First(&job)
var job model.MidJourneyJob
res := s.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil { if res.Error != nil {
logger.Warn("非法任务:", res.Error) logger.Warn("非法任务:", res.Error)
return return
@ -133,125 +121,27 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress job.Progress = data.Progress
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.Hash = data.Image.Hash job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL // save origin image
// 任务完成,将最终的图片下载下来 // upload image
if data.Progress == 100 {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil { if err != nil {
logger.Error("error with download img: ", err.Error()) logger.Error("error with download img: ", err.Error())
return return
} }
job.ImgURL = imgURL job.ImgURL = imgURL
} else {
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = s.db.Updates(&job) res = s.db.Updates(&job)
if res.Error != nil { if res.Error != nil {
logger.Error("error with update job: ", res.Error) logger.Error("error with update job: ", res.Error)
return return
} }
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
if data.Progress < 100 {
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
// 推送任务到前端
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := s.ChatClients.Get(task.SessionId)
if data.Status == Finished { if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" { // update user's img calls
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
utils.ReplyMessage(wsClient, content) // release lock task
} atomic.AddInt32(&s.handledTaskNum, -1)
// download image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download image: ", err)
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
return
}
tx := s.db.Begin()
data.Image.URL = imgURL
message := model.HistoryMessage{
UserId: uint(task.UserId),
ChatId: task.ChatId,
RoleId: uint(task.RoleId),
Type: types.MjMsg,
Icon: task.Icon,
Content: utils.JsonEncode(data),
Tokens: 0,
UseContext: false,
}
res = tx.Create(&message)
if res.Error != nil {
logger.Error("error with update database: ", err)
return
}
// save the job
job.UserId = task.UserId
job.Type = task.Type.String()
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Prompt = data.Prompt
job.ImgURL = imgURL
job.Progress = data.Progress
job.Hash = data.Image.Hash
job.CreatedAt = time.Now()
res = tx.Create(&job)
if res.Error != nil {
logger.Error("error with update database: ", err)
tx.Rollback()
return
}
tx.Commit()
}
if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data)
return
}
if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// 本次绘画完毕,移除客户端
s.ChatClients.Delete(task.SessionId)
} else {
// 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
}
}
// 更新用户剩余绘图次数
// TODO: 放大图片是否需要消耗绘图次数?
if data.Status == Finished {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 解除任务锁定
s.redis.Del(context.Background(), RunningJobKey)
} }
} }

View File

@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
var data string var data string
keys := make([]string, 0, 0) keys := make([]string, 0, 0)
params["appid"] = s.appId params["appid"] = s.appId
for key, _ := range params { for key := range params {
keys = append(keys, key) keys = append(keys, key)
} }
sort.Strings(keys) sort.Strings(keys)

View File

@ -6,13 +6,14 @@ type MidJourneyJob struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
Type string Type string
UserId int UserId int
TaskId string
MessageId string MessageId string
ReferenceId string ReferenceId string
ImgURL string ImgURL string
OrgURL string // 原图地址
Hash string // message hash Hash string // message hash
Progress int Progress int
Prompt string Prompt string
Started bool
CreatedAt time.Time CreatedAt time.Time
} }

View File

@ -9,9 +9,9 @@ type MidJourneyJob struct {
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"` ReferenceId string `json:"reference_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
OrgURL string `json:"org_url"`
Hash string `json:"hash"` Hash string `json:"hash"`
Progress int `json:"progress"` Progress int `json:"progress"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
} }

View File

@ -504,72 +504,72 @@ const socket = ref(null)
const imgCalls = ref(0) const imgCalls = ref(0)
const loading = ref(false) const loading = ref(false)
const connect = () => { // const connect = () => {
let host = process.env.VUE_APP_WS_HOST // let host = process.env.VUE_APP_WS_HOST
if (host === '') { // if (host === '') {
if (location.protocol === 'https:') { // if (location.protocol === 'https:') {
host = 'wss://' + location.host; // host = 'wss://' + location.host;
} else { // } else {
host = 'ws://' + location.host; // host = 'ws://' + location.host;
} // }
} // }
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`); // const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => { // _socket.addEventListener('open', () => {
socket.value = _socket; // socket.value = _socket;
}); // });
//
_socket.addEventListener('message', event => { // _socket.addEventListener('message', event => {
if (event.data instanceof Blob) { // if (event.data instanceof Blob) {
const reader = new FileReader(); // const reader = new FileReader();
reader.readAsText(event.data, "UTF-8"); // reader.readAsText(event.data, "UTF-8");
reader.onload = () => { // reader.onload = () => {
const data = JSON.parse(String(reader.result)); // const data = JSON.parse(String(reader.result));
let isNew = true // let isNew = true
if (data.progress === 100) { // if (data.progress === 100) {
for (let i = 0; i < finishedJobs.value.length; i++) { // for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) { // if (finishedJobs.value[i].id === data.id) {
isNew = false // isNew = false
break // break
} // }
} // }
for (let i = 0; i < runningJobs.value.length; i++) { // for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) { // if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1) // runningJobs.value.splice(i, 1)
break // break
} // }
} // }
if (isNew) { // if (isNew) {
finishedJobs.value.unshift(data) // finishedJobs.value.unshift(data)
} // }
} else if (data.progress === -1) { // // } else if (data.progress === -1) { //
ElNotification({ // ElNotification({
title: '任务执行失败', // title: '',
message: "提示词:" + data['prompt'], // message: "" + data['prompt'],
type: 'error', // type: 'error',
}) // })
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id) // runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
//
} else { // } else {
for (let i = 0; i < runningJobs.value.length; i++) { // for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) { // if (runningJobs.value[i].id === data.id) {
isNew = false // isNew = false
runningJobs.value[i] = data // runningJobs.value[i] = data
break // break
} // }
} // }
if (isNew) { // if (isNew) {
runningJobs.value.push(data) // runningJobs.value.push(data)
} // }
} // }
} // }
} // }
}); // });
//
_socket.addEventListener('close', () => { // _socket.addEventListener('close', () => {
ElMessage.error("Websocket 已经断开,正在重新连接服务器") // ElMessage.error("Websocket ")
connect() // connect()
}); // });
} // }
const translatePrompt = () => { const translatePrompt = () => {
loading.value = true loading.value = true