refactor: refactor stable diffusion service, add service pool support

This commit is contained in:
RockYang 2023-12-14 16:48:54 +08:00
parent 10ba1430f9
commit d2991e60b6
14 changed files with 266 additions and 253 deletions

View File

@ -169,9 +169,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/sd/client" {
} else if c.Request.URL.Path == "/api/chat/new" {
tokenString = c.Query("token")
} else {
tokenString = c.GetHeader(types.UserAuthHeader)

View File

@ -33,7 +33,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload",
},
},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
}

View File

@ -16,11 +16,11 @@ type AppConfig struct {
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj 绘画配置池子
WeChatBot bool // 是否启用微信机器人
SdConfig StableDiffusionConfig // sd 绘画配置
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig

View File

@ -8,47 +8,30 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
service *sd.Service
redis *redis.Client
db *gorm.DB
pool *sd.ServicePool
}
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler {
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool) *SdJobHandler {
h := SdJobHandler{
redis: redisCli,
db: db,
service: service,
db: db,
pool: pool,
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.service.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
@ -56,6 +39,11 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false
}
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
return false
@ -67,11 +55,6 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.App.Config.SdConfig.Enabled {
resp.ERROR(c, "Stable Diffusion service is disabled")
return
}
if !h.checkLimits(c) {
return
}
@ -129,7 +112,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Started: false,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
@ -138,7 +120,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
h.service.PushTask(types.SdTask{
h.pool.PushTask(types.SdTask{
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskImage,
@ -146,15 +128,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: params,
UserId: userId,
})
var jobVo vo.SdJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.service.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c)
}
@ -193,12 +167,22 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if err != nil {
continue
}
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
// 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item)
continue
}
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
jobs = append(jobs, job)
}

View File

@ -167,14 +167,7 @@ func main() {
fx.Provide(mj.NewServicePool),
// Stable Diffusion 机器人
fx.Provide(sd.NewService),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
if config.SdConfig.Enabled {
go func() {
service.Run()
}()
}
}),
fx.Provide(sd.NewServicePool),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),

View File

@ -1,21 +1,21 @@
{
"data": [
"task(s95jqt5jr8yppcp)",
"A beautiful Chinese girl in a garden",
"task(owy5niy1sbbnlq0)",
"A beautiful Chinese girl plays the guitar on the beach. She is dressed in a flowing dress that matches the colors of the sunset. With her eyes closed, she strums the guitar with passion and confidence, her fingers dancing gracefully on the strings. The painting employs a vibrant color palette, capturing the warmth of the setting sun blending with the serene hues of the ocean. The artist uses a combination of impressionistic and realistic brushstrokes to convey both the girl's delicate features and the dynamic movement of the waves. The rendering effect creates a dream-like atmosphere, as if the viewer is being transported to a magical realm where music and nature intertwine. The picture is bathed in a soft, golden light, casting a warm glow on the girl's face, illuminating her joy and connection to the music she creates.",
"",
[],
30,
"Euler a",
"DPM++ 3M SDE Karras",
1,
1,
7,
512,
512,
true,
false,
0.7,
2,
"Latent",
10,
0,
0,
0,
"Use same checkpoint",
@ -33,6 +33,9 @@
0,
0,
0,
null,
null,
null,
false,
false,
"positive",
@ -55,13 +58,22 @@
false,
false,
0,
[
],
null,
null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"",
"",
""
],
"event_data": null,
"fn_index": 95,
"session_hash": "eqwumnt3rov"
"fn_index": 316,
"session_hash": "ttr8efgt63g"
}

View File

@ -60,7 +60,7 @@ func (p *ServicePool) PushTask(task types.MjTask) {
p.taskQueue.RPush(task)
}
// HasAvailableService check if has available mj service in pool
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@ -2,7 +2,6 @@ package mj
import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
@ -24,7 +23,6 @@ type Service struct {
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
snowflake *service.Snowflake
}
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
@ -127,6 +125,12 @@ func (s *Service) Notify(data CBReq) {
job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
// upload image
if data.Status == Finished {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
@ -135,12 +139,7 @@ func (s *Service) Notify(data CBReq) {
return
}
job.ImgURL = imgURL
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
s.db.Updates(&job)
}
if data.Status == Finished {

52
api/service/sd/pool.go Normal file
View File

@ -0,0 +1,52 @@
package sd
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 4, 600, &config, queue, db, manager)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: queue,
services: services,
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@ -5,84 +5,96 @@ import (
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/json"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"os"
"strconv"
"sync/atomic"
"time"
)
// SD 绘画服务
const RunningJobKey = "StableDiffusion_Running_Job"
type Service struct {
httpClient *req.Client
config *types.StableDiffusionConfig
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池
httpClient *req.Client
config *types.StableDiffusionConfig
taskQueue *store.RedisQueue
db *gorm.DB
uploadManager *oss.UploaderManager
name string // service name
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
}
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service {
func NewService(name string, maxTaskNum int32, timeout int64, config *types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
return &Service{
config: &config.SdConfig,
httpClient: req.C(),
redis: redisCli,
db: db,
uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](),
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
name: name,
config: config,
httpClient: req.C(),
taskQueue: queue,
db: db,
uploadManager: manager,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time),
}
}
func (s *Service) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
continue
}
// 更新任务的执行状态
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
}
}
// PushTask 推送任务到队列
func (s *Service) PushTask(task types.SdTask) {
logger.Infof("add a new Stable Diffusion Task: %+v", task)
s.taskQueue.RPush(task)
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
// Txt2Img 文生图 API
@ -237,9 +249,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
}
func (s *Service) callback(data CBReq) {
// 释放任务锁
s.redis.Del(context.Background(), RunningJobKey)
client := s.Clients.Get(data.SessionId)
// release task num
atomic.AddInt32(&s.handledTaskNum, -1)
if data.Success { // 任务成功
var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job)
@ -259,13 +270,15 @@ func (s *Service) callback(data CBReq) {
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
if data.Progress == 100 {
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
job.ImgURL = imageURL
}
job.Params = utils.JsonEncode(params)
@ -275,38 +288,16 @@ func (s *Service) callback(data CBReq) {
return
}
var jobVo vo.SdJob
err = utils.CopyObject(job, &jobVo)
if err != nil {
logger.Error("error with copy object: ", err)
return
}
if data.Progress < 100 && data.ImageData != "" {
jobVo.ImgURL = data.ImageData
}
logger.Infof("绘图进度:%d", data.Progress)
logger.Debugf("绘图进度:%d", data.Progress)
// 扣减绘图次数
if data.Progress == 100 {
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", jobVo.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
// 推送任务到前端
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// 删除任务
s.db.Delete(&model.SdJob{Id: uint(data.JobId)})
// 推送消息到前端
if client != nil {
utils.ReplyChunkMessage(client, vo.SdJob{
Id: uint(data.JobId),
Progress: -1,
TaskId: data.TaskId,
})
}
// update the task progress
s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1)
}
}

View File

@ -11,7 +11,6 @@ type SdJob struct {
Progress int
Prompt string
Params string
Started bool
CreatedAt time.Time
}

View File

@ -15,5 +15,4 @@ type SdJob struct {
Progress int `json:"progress"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
}

View File

@ -266,7 +266,6 @@
翻译并重写
</el-button>
</el-tooltip>
</div>
</div>
</div>
@ -580,7 +579,7 @@ const fetchRunningJobs = (userId) => {
}
runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 10000)
setTimeout(() => fetchRunningJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
@ -591,7 +590,7 @@ const fetchFinishJobs = (userId) => {
//
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 10000)
setTimeout(() => fetchFinishJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})

View File

@ -241,7 +241,7 @@
</div>
</div>
<div class="param-line">
<div class="param-line" v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.8)">
<el-input
v-model="params.prompt"
:autosize="{ minRows: 4, maxRows: 6 }"
@ -251,6 +251,30 @@
/>
</div>
<div style="padding: 10px">
<el-button type="primary" @click="translatePrompt" size="small">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译
</el-button>
<el-tooltip
class="box-item"
effect="dark"
raw-content
content="使用 AI 翻译并重写提示词,<br/>增加更多细节,风格等描述"
placement="top-end"
>
<el-button type="success" @click="rewritePrompt" size="small">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译并重写
</el-button>
</el-tooltip>
</div>
<div class="param-line pt">
<span>反向提示词</span>
<el-tooltip
@ -272,12 +296,8 @@
/>
</div>
<div class="param-line pt">
<el-form-item label="剩余次数">
<template #default>
<el-tag type="info">{{ imgCalls }}</el-tag>
</template>
</el-form-item>
<div class="param-line" style="padding: 10px">
<el-tag type="success">绘图可用额度{{ imgCalls }}</el-tag>
</div>
</el-form>
</div>
@ -478,21 +498,21 @@
<script setup>
import {onMounted, ref} from "vue"
import {DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue";
import {DocumentCopy, InfoFilled, Orange, Picture, Refresh} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElNotification} from "element-plus";
import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard";
import {checkSession} from "@/action/session";
import {useRouter} from "vue-router";
import {getSessionId, getUserToken} from "@/store/session";
import {removeArrayItem} from "@/utils/libs";
import {getSessionId} from "@/store/session";
const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150)
const fullImgHeight = ref(window.innerHeight - 60)
const showTaskDialog = ref(false)
const item = ref({})
const loading = ref(false)
window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40
@ -515,116 +535,84 @@ const params = ref({
hd_scale_alg: scaleAlg[0],
hd_steps: 10,
prompt: "",
negative_prompt: "nsfw, paintings, cartoon, anime, sketches, low quality,easynegative,ng_deepnegative _v1 75t,(worst quality:2),(low quality:2),(normalquality:2),lowres,bad anatomy,bad hands,normal quality,((monochrome)),((grayscale)),((watermark))",
negative_prompt: "nsfw, paintings,low quality,easynegative,ng_deepnegative ,lowres,bad anatomy,bad hands,bad feet",
})
const runningJobs = ref([])
const finishedJobs = ref([])
const previewImgList = ref([])
const router = useRouter()
//
const _params = router.currentRoute.value.params["copyParams"]
if (_params) {
params.value = JSON.parse(_params)
}
const socket = ref(null)
const imgCalls = ref(0)
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/sd/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
const rewritePrompt = () => {
loading.value = true
httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => {
params.value.prompt = res.data
loading.value = false
}).catch(e => {
ElMessage.error("翻译失败:" + e.message)
})
}
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
let append = true
if (data.progress === 100) { //
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
append = false
break
}
}
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1)
break
}
}
if (append) {
finishedJobs.value.unshift(data)
}
previewImgList.value.unshift(data["img_url"])
} else if (data.progress === -1) { //
ElNotification({
title: '任务执行失败',
message: "任务ID" + data['task_id'],
type: 'error',
})
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
} else { //
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
append = false
runningJobs.value[i] = data
break
}
}
if (append) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
connect()
});
const translatePrompt = () => {
loading.value = true
httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
params.value.prompt = res.data
loading.value = false
}).catch(e => {
ElMessage.error("翻译失败:" + e.message)
})
}
onMounted(() => {
checkSession().then(user => {
imgCalls.value = user['img_calls']
//
httpGet(`/api/sd/jobs?status=0&user_id=${user['id']}`).then(res => {
runningJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
//
httpGet(`/api/sd/jobs?status=1&user_id=${user['id']}`).then(res => {
finishedJobs.value = res.data
previewImgList.value = []
for (let index in finishedJobs.value) {
previewImgList.value.push(finishedJobs.value[index]["img_url"])
}
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
fetchRunningJobs(user.id)
fetchFinishJobs(user.id)
// socket
connect();
}).catch(() => {
router.push('/login')
});
const fetchRunningJobs = (userId) => {
//
httpGet(`/api/sd/jobs?status=0&user_id=${userId}`).then(res => {
const jobs = res.data
const _jobs = []
for (let i = 0; i < jobs.length; i++) {
if (jobs[i].progress === -1) {
ElNotification({
title: '任务执行失败',
message: "任务ID" + jobs[i]['task_id'],
type: 'error',
})
continue
}
_jobs.push(jobs[i])
}
runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
//
const fetchFinishJobs = (userId) => {
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
const clipboard = new Clipboard('.copy-prompt');
clipboard.on('success', () => {
ElMessage.success("复制成功!");