mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
refactor: add midjourney pool implementation, add translate prompt for mj drawing
This commit is contained in:
parent
96816c12ca
commit
a398e7a550
@ -33,7 +33,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
BasePath: "./static/upload",
|
BasePath: "./static/upload",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
MjConfigs: types.MidJourneyConfig{Enabled: false},
|
|
||||||
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
|
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
|
||||||
WeChatBot: false,
|
WeChatBot: false,
|
||||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/handler"
|
"chatplus/handler"
|
||||||
logger2 "chatplus/logger"
|
logger2 "chatplus/logger"
|
||||||
"chatplus/service/mj"
|
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
@ -32,16 +31,14 @@ var logger = logger2.GetLogger()
|
|||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
handler.BaseHandler
|
handler.BaseHandler
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
mjService *mj.Service
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
|
||||||
h := ChatHandler{
|
h := ChatHandler{
|
||||||
db: db,
|
db: db,
|
||||||
redis: redis,
|
redis: redis,
|
||||||
mjService: service,
|
|
||||||
}
|
}
|
||||||
h.App = app
|
h.App = app
|
||||||
return &h
|
return &h
|
||||||
|
@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service"
|
||||||
"chatplus/service/mj"
|
"chatplus/service/mj"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@ -11,7 +12,6 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -19,26 +19,22 @@ import (
|
|||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
mjService *mj.Service
|
pool *mj.ServicePool
|
||||||
|
snowflake *service.Snowflake
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyHandler(
|
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler {
|
||||||
app *core.AppServer,
|
|
||||||
client *redis.Client,
|
|
||||||
db *gorm.DB,
|
|
||||||
mjService *mj.Service) *MidJourneyHandler {
|
|
||||||
h := MidJourneyHandler{
|
h := MidJourneyHandler{
|
||||||
redis: client,
|
|
||||||
db: db,
|
db: db,
|
||||||
mjService: mjService,
|
snowflake: snowflake,
|
||||||
|
pool: pool,
|
||||||
}
|
}
|
||||||
h.App = app
|
h.App = app
|
||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := utils.GetLoginUser(c, h.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.NotAuth(c)
|
resp.NotAuth(c)
|
||||||
@ -50,17 +46,17 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !h.pool.HasAvailableService() {
|
||||||
|
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||||
if !h.App.Config.MjConfigs[0].Enabled {
|
|
||||||
resp.ERROR(c, "MidJourney service is disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
@ -80,7 +76,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !h.checkLimits(c) {
|
if !h.preCheck(c) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,9 +117,16 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
// generate task id
|
||||||
|
taskId, err := h.snowflake.Next(true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
job := model.MidJourneyJob{
|
job := model.MidJourneyJob{
|
||||||
Type: types.TaskImage.String(),
|
Type: types.TaskImage.String(),
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
|
TaskId: taskId,
|
||||||
Progress: 0,
|
Progress: 0,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@ -133,24 +136,13 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mjService.PushTask(types.MjTask{
|
h.pool.PushTask(types.MjTask{
|
||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Src: types.TaskSrcImg,
|
|
||||||
Type: types.TaskImage,
|
Type: types.TaskImage,
|
||||||
Prompt: prompt,
|
Prompt: fmt.Sprintf("%s %s", taskId, prompt),
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
})
|
})
|
||||||
|
|
||||||
var jobVo vo.MidJourneyJob
|
|
||||||
err := utils.CopyObject(job, &jobVo)
|
|
||||||
if err == nil {
|
|
||||||
// 推送任务到前端
|
|
||||||
client := h.mjService.Clients.Get(data.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,65 +166,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.checkLimits(c) {
|
if !h.preCheck(c) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
src := types.TaskSrc(data.Src)
|
h.pool.PushTask(types.MjTask{
|
||||||
if src == types.TaskSrcImg {
|
|
||||||
job := model.MidJourneyJob{
|
|
||||||
Type: types.TaskUpscale.String(),
|
|
||||||
UserId: userId,
|
|
||||||
Hash: data.MessageHash,
|
|
||||||
Progress: 0,
|
|
||||||
Prompt: data.Prompt,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
if res := h.db.Create(&job); res.Error == nil {
|
|
||||||
jobId = int(job.Id)
|
|
||||||
} else {
|
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var jobVo vo.MidJourneyJob
|
|
||||||
err := utils.CopyObject(job, &jobVo)
|
|
||||||
if err == nil {
|
|
||||||
// 推送任务到前端
|
|
||||||
client := h.mjService.Clients.Get(data.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.mjService.PushTask(types.MjTask{
|
|
||||||
Id: jobId,
|
Id: jobId,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Src: src,
|
|
||||||
Type: types.TaskUpscale,
|
Type: types.TaskUpscale,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
RoleId: data.RoleId,
|
|
||||||
Icon: data.Icon,
|
|
||||||
ChatId: data.ChatId,
|
|
||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
})
|
||||||
|
|
||||||
if src == types.TaskSrcChat {
|
|
||||||
wsClient := h.App.ChatClients.Get(data.SessionId)
|
|
||||||
if wsClient != nil {
|
|
||||||
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
|
||||||
utils.ReplyMessage(wsClient, content)
|
|
||||||
if h.mjService.ChatClients.Get(data.SessionId) == nil {
|
|
||||||
h.mjService.ChatClients.Put(data.SessionId, wsClient)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,67 +194,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.checkLimits(c) {
|
if !h.preCheck(c) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
src := types.TaskSrc(data.Src)
|
h.pool.PushTask(types.MjTask{
|
||||||
if src == types.TaskSrcImg {
|
|
||||||
job := model.MidJourneyJob{
|
|
||||||
Type: types.TaskVariation.String(),
|
|
||||||
UserId: userId,
|
|
||||||
ImgURL: "",
|
|
||||||
Hash: data.MessageHash,
|
|
||||||
Progress: 0,
|
|
||||||
Prompt: data.Prompt,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
if res := h.db.Create(&job); res.Error == nil {
|
|
||||||
jobId = int(job.Id)
|
|
||||||
} else {
|
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var jobVo vo.MidJourneyJob
|
|
||||||
err := utils.CopyObject(job, &jobVo)
|
|
||||||
if err == nil {
|
|
||||||
// 推送任务到前端
|
|
||||||
client := h.mjService.Clients.Get(data.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.mjService.PushTask(types.MjTask{
|
|
||||||
Id: jobId,
|
Id: jobId,
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Src: src,
|
|
||||||
Type: types.TaskVariation,
|
Type: types.TaskVariation,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
RoleId: data.RoleId,
|
|
||||||
Icon: data.Icon,
|
|
||||||
ChatId: data.ChatId,
|
|
||||||
Index: data.Index,
|
Index: data.Index,
|
||||||
MessageId: data.MessageId,
|
MessageId: data.MessageId,
|
||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
})
|
||||||
|
|
||||||
if src == types.TaskSrcChat {
|
|
||||||
// 从聊天窗口发送的请求,记录客户端信息
|
|
||||||
wsClient := h.mjService.ChatClients.Get(data.SessionId)
|
|
||||||
if wsClient != nil {
|
|
||||||
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
|
||||||
utils.ReplyMessage(wsClient, content)
|
|
||||||
if h.mjService.Clients.Get(data.SessionId) == nil {
|
|
||||||
h.mjService.Clients.Put(data.SessionId, wsClient)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,19 +249,27 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if job.Progress == -1 {
|
||||||
|
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
|
||||||
|
}
|
||||||
|
|
||||||
if item.Progress < 100 {
|
if item.Progress < 100 {
|
||||||
// 10 分钟还没完成的任务直接删除
|
// 10 分钟还没完成的任务直接删除
|
||||||
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
|
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
|
||||||
h.db.Delete(&item)
|
h.db.Delete(&item)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
|
||||||
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
// 正在运行中任务使用代理访问图片
|
||||||
|
if item.ImgURL == "" && item.OrgURL != "" {
|
||||||
|
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
resp.SUCCESS(c, jobs)
|
||||||
|
@ -176,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
orderNo, err := h.snowflake.Next()
|
orderNo, err := h.snowflake.Next(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||||
return
|
return
|
||||||
|
@ -13,7 +13,8 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||||
|
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||||
|
|
||||||
type PromptHandler struct {
|
type PromptHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
@ -47,6 +48,25 @@ type apiErrRes struct {
|
|||||||
} `json:"error"`
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rewrite translate and rewrite prompt with ChatGPT
|
||||||
|
func (h *PromptHandler) Rewrite(c *gin.Context) {
|
||||||
|
var data struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := h.request(data.Prompt, rewritePromptTemplate)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, content)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *PromptHandler) Translate(c *gin.Context) {
|
func (h *PromptHandler) Translate(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
content, err := h.request(data.Prompt, translatePromptTemplate)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
|
||||||
// 获取 OpenAI 的 API KEY
|
// 获取 OpenAI 的 API KEY
|
||||||
var apiKey model.ApiKey
|
var apiKey model.ApiKey
|
||||||
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
|
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "找不到可用 OpenAI API KEY")
|
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := make([]interface{}, 1)
|
messages := make([]interface{}, 1)
|
||||||
messages[0] = types.Message{
|
messages[0] = types.Message{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: fmt.Sprintf(translatePromptTemplate, data.Prompt),
|
Content: fmt.Sprintf(promptTemplate, prompt),
|
||||||
}
|
}
|
||||||
|
|
||||||
var response apiRes
|
var response apiRes
|
||||||
@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) {
|
|||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
|
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message))
|
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c, response.Choices[0].Message.Content)
|
return response.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
h.service.PushTask(types.SdTask{
|
h.service.PushTask(types.SdTask{
|
||||||
Id: int(job.Id),
|
Id: int(job.Id),
|
||||||
SessionId: data.SessionId,
|
SessionId: data.SessionId,
|
||||||
Src: types.TaskSrcImg,
|
|
||||||
Type: types.TaskImage,
|
Type: types.TaskImage,
|
||||||
Prompt: data.Prompt,
|
Prompt: data.Prompt,
|
||||||
Params: params,
|
Params: params,
|
||||||
|
@ -163,8 +163,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// MidJourney 机器人
|
// MidJourney service pool
|
||||||
fx.Provide(mj.NewBot),
|
fx.Provide(mj.NewServicePool),
|
||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewService),
|
fx.Provide(sd.NewService),
|
||||||
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
|
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
|
||||||
@ -341,6 +342,7 @@ func main() {
|
|||||||
fx.Provide(handler.NewPromptHandler),
|
fx.Provide(handler.NewPromptHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||||
group := s.Engine.Group("/api/prompt/")
|
group := s.Engine.Group("/api/prompt/")
|
||||||
|
group.POST("rewrite", h.Rewrite)
|
||||||
group.POST("translate", h.Translate)
|
group.POST("translate", h.Translate)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
@ -19,17 +19,18 @@ var logger = logger2.GetLogger()
|
|||||||
type Bot struct {
|
type Bot struct {
|
||||||
config *types.MidJourneyConfig
|
config *types.MidJourneyConfig
|
||||||
bot *discordgo.Session
|
bot *discordgo.Session
|
||||||
|
name string
|
||||||
service *Service
|
service *Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
|
||||||
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
|
discord, err := discordgo.New("Bot " + config.BotToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ProxyURL != "" {
|
if proxy != "" {
|
||||||
proxy, _ := url.Parse(config.ProxyURL)
|
proxy, _ := url.Parse(proxy)
|
||||||
discord.Client = &http.Client{
|
discord.Client = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Proxy: http.ProxyURL(proxy),
|
Proxy: http.ProxyURL(proxy),
|
||||||
@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Bot{
|
return &Bot{
|
||||||
config: &config.MjConfigs,
|
config: config,
|
||||||
bot: discord,
|
bot: discord,
|
||||||
|
name: name,
|
||||||
service: service,
|
service: service,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -52,13 +54,13 @@ func (b *Bot) Run() error {
|
|||||||
b.bot.AddHandler(b.messageCreate)
|
b.bot.AddHandler(b.messageCreate)
|
||||||
b.bot.AddHandler(b.messageUpdate)
|
b.bot.AddHandler(b.messageUpdate)
|
||||||
|
|
||||||
logger.Info("Starting MidJourney Bot...")
|
logger.Infof("Starting MidJourney %s", b.name)
|
||||||
err := b.bot.Open()
|
err := b.bot.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error opening Discord connection:", err)
|
logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logger.Info("Starting MidJourney Bot successfully!")
|
logger.Infof("Starting MidJourney %s successfully!", b.name)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/utils"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"time"
|
"time"
|
||||||
@ -18,9 +17,10 @@ type Client struct {
|
|||||||
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
|
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
|
||||||
client := req.C().SetTimeout(10 * time.Second)
|
client := req.C().SetTimeout(10 * time.Second)
|
||||||
// set proxy URL
|
// set proxy URL
|
||||||
if utils.IsEmptyValue(proxy) {
|
if proxy != "" {
|
||||||
client.SetProxyURL(proxy)
|
client.SetProxyURL(proxy)
|
||||||
}
|
}
|
||||||
|
logger.Info(proxy)
|
||||||
return &Client{client: client, config: config}
|
return &Client{client: client, config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,35 +4,63 @@ import (
|
|||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ServicePool Mj service pool
|
// ServicePool Mj service pool
|
||||||
type ServicePool struct {
|
type ServicePool struct {
|
||||||
services []Service
|
services []*Service
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||||
|
services := make([]*Service, 0)
|
||||||
|
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
|
||||||
// create mj client and service
|
// create mj client and service
|
||||||
for _, config := range appConfig.MjConfigs {
|
for k, config := range appConfig.MjConfigs {
|
||||||
if config.Enabled == false {
|
if config.Enabled == false {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// create mj client
|
// create mj client
|
||||||
client := NewClient(&config, appConfig.ProxyURL)
|
client := NewClient(&config, appConfig.ProxyURL)
|
||||||
|
|
||||||
|
name := fmt.Sprintf("MjService-%d", k)
|
||||||
// create mj service
|
// create mj service
|
||||||
service := NewService()
|
service := NewService(name, queue, 4, 600, db, client, manager, appConfig)
|
||||||
|
botName := fmt.Sprintf("MjBot-%d", k)
|
||||||
|
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bot.Run()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// run mj service
|
||||||
|
go func() {
|
||||||
|
service.Run()
|
||||||
|
}()
|
||||||
|
|
||||||
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ServicePool{
|
return &ServicePool{
|
||||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
taskQueue: queue,
|
||||||
|
services: services,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PushTask push a new mj task in to task queue
|
||||||
func (p *ServicePool) PushTask(task types.MjTask) {
|
func (p *ServicePool) PushTask(task types.MjTask) {
|
||||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||||
p.taskQueue.RPush(task)
|
p.taskQueue.RPush(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasAvailableService check if has available mj service in pool
|
||||||
|
func (p *ServicePool) HasAvailableService() bool {
|
||||||
|
return len(p.services) > 0
|
||||||
|
}
|
||||||
|
@ -27,16 +27,17 @@ type Service struct {
|
|||||||
snowflake *service.Snowflake
|
snowflake *service.Snowflake
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
name: name,
|
name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: queue,
|
taskQueue: queue,
|
||||||
client: client,
|
client: client,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
taskTimeout: timeout,
|
taskTimeout: timeout,
|
||||||
proxyURL: config.ProxyURL,
|
maxHandleTaskNum: maxTaskNum,
|
||||||
taskStartTimes: make(map[int]time.Time, 0),
|
proxyURL: config.ProxyURL,
|
||||||
|
taskStartTimes: make(map[int]time.Time, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +59,7 @@ func (s *Service) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("handle a new MidJourney task: %+v", task)
|
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
||||||
switch task.Type {
|
switch task.Type {
|
||||||
case types.TaskImage:
|
case types.TaskImage:
|
||||||
err = s.client.Imagine(task.Prompt)
|
err = s.client.Imagine(task.Prompt)
|
||||||
@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool {
|
|||||||
return handledNum < s.maxHandleTaskNum
|
return handledNum < s.maxHandleTaskNum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove the expired tasks
|
||||||
func (s *Service) checkTasks() {
|
func (s *Service) checkTasks() {
|
||||||
for k, t := range s.taskStartTimes {
|
for k, t := range s.taskStartTimes {
|
||||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||||
delete(s.taskStartTimes, k)
|
delete(s.taskStartTimes, k)
|
||||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||||
|
// delete task from database
|
||||||
|
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) {
|
|||||||
job.Progress = data.Progress
|
job.Progress = data.Progress
|
||||||
job.Prompt = data.Prompt
|
job.Prompt = data.Prompt
|
||||||
job.Hash = data.Image.Hash
|
job.Hash = data.Image.Hash
|
||||||
job.OrgURL = data.Image.URL // save origin image
|
job.OrgURL = data.Image.URL
|
||||||
|
|
||||||
// upload image
|
// upload image
|
||||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
if data.Status == Finished {
|
||||||
if err != nil {
|
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||||
logger.Error("error with download img: ", err.Error())
|
if err != nil {
|
||||||
return
|
logger.Error("error with download img: ", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
job.ImgURL = imgURL
|
||||||
}
|
}
|
||||||
job.ImgURL = imgURL
|
|
||||||
|
|
||||||
res = s.db.Updates(&job)
|
res = s.db.Updates(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
|
@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Next 生成一个新的唯一ID
|
// Next 生成一个新的唯一ID
|
||||||
func (s *Snowflake) Next() (string, error) {
|
func (s *Snowflake) Next(raw bool) (string, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
|
|||||||
|
|
||||||
s.lastTimestamp = timestamp
|
s.lastTimestamp = timestamp
|
||||||
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
|
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
|
||||||
|
if raw {
|
||||||
|
return fmt.Sprintf("%d", id), nil
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
|
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ type MidJourneyJob struct {
|
|||||||
Id uint `json:"id"`
|
Id uint `json:"id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
ReferenceId string `json:"reference_id"`
|
ReferenceId string `json:"reference_id"`
|
||||||
ImgURL string `json:"img_url"`
|
ImgURL string `json:"img_url"`
|
||||||
|
@ -244,12 +244,30 @@
|
|||||||
</el-icon>
|
</el-icon>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
</div>
|
</div>
|
||||||
<el-button type="success" @click="translatePrompt">
|
<div>
|
||||||
<el-icon style="margin-right: 6px;font-size: 18px;">
|
<el-button type="primary" @click="translatePrompt">
|
||||||
<Refresh/>
|
<el-icon style="margin-right: 6px;font-size: 18px;">
|
||||||
</el-icon>
|
<Refresh/>
|
||||||
翻译
|
</el-icon>
|
||||||
</el-button>
|
翻译
|
||||||
|
</el-button>
|
||||||
|
|
||||||
|
<el-tooltip
|
||||||
|
class="box-item"
|
||||||
|
effect="light"
|
||||||
|
raw-content
|
||||||
|
content="使用 AI 翻译并重写提示词,<br/>增加更多细节,风格等描述"
|
||||||
|
placement="top-end"
|
||||||
|
>
|
||||||
|
<el-button type="success" @click="rewritePrompt">
|
||||||
|
<el-icon style="margin-right: 6px;font-size: 18px;">
|
||||||
|
<Refresh/>
|
||||||
|
</el-icon>
|
||||||
|
翻译并重写
|
||||||
|
</el-button>
|
||||||
|
</el-tooltip>
|
||||||
|
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -432,8 +450,7 @@ import ItemList from "@/components/ItemList.vue";
|
|||||||
import Clipboard from "clipboard";
|
import Clipboard from "clipboard";
|
||||||
import {checkSession} from "@/action/session";
|
import {checkSession} from "@/action/session";
|
||||||
import {useRouter} from "vue-router";
|
import {useRouter} from "vue-router";
|
||||||
import {getSessionId, getUserToken} from "@/store/session";
|
import {getSessionId} from "@/store/session";
|
||||||
import {removeArrayItem} from "@/utils/libs";
|
|
||||||
|
|
||||||
const listBoxHeight = ref(window.innerHeight - 40)
|
const listBoxHeight = ref(window.innerHeight - 40)
|
||||||
const mjBoxHeight = ref(window.innerHeight - 150)
|
const mjBoxHeight = ref(window.innerHeight - 150)
|
||||||
@ -504,79 +521,22 @@ const socket = ref(null)
|
|||||||
const imgCalls = ref(0)
|
const imgCalls = ref(0)
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
|
|
||||||
// const connect = () => {
|
const rewritePrompt = () => {
|
||||||
// let host = process.env.VUE_APP_WS_HOST
|
loading.value = true
|
||||||
// if (host === '') {
|
httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => {
|
||||||
// if (location.protocol === 'https:') {
|
params.value.prompt = res.data
|
||||||
// host = 'wss://' + location.host;
|
loading.value = false
|
||||||
// } else {
|
}).catch(e => {
|
||||||
// host = 'ws://' + location.host;
|
ElMessage.error("翻译失败:" + e.message)
|
||||||
// }
|
})
|
||||||
// }
|
}
|
||||||
// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
|
||||||
// _socket.addEventListener('open', () => {
|
|
||||||
// socket.value = _socket;
|
|
||||||
// });
|
|
||||||
//
|
|
||||||
// _socket.addEventListener('message', event => {
|
|
||||||
// if (event.data instanceof Blob) {
|
|
||||||
// const reader = new FileReader();
|
|
||||||
// reader.readAsText(event.data, "UTF-8");
|
|
||||||
// reader.onload = () => {
|
|
||||||
// const data = JSON.parse(String(reader.result));
|
|
||||||
// let isNew = true
|
|
||||||
// if (data.progress === 100) {
|
|
||||||
// for (let i = 0; i < finishedJobs.value.length; i++) {
|
|
||||||
// if (finishedJobs.value[i].id === data.id) {
|
|
||||||
// isNew = false
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// for (let i = 0; i < runningJobs.value.length; i++) {
|
|
||||||
// if (runningJobs.value[i].id === data.id) {
|
|
||||||
// runningJobs.value.splice(i, 1)
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if (isNew) {
|
|
||||||
// finishedJobs.value.unshift(data)
|
|
||||||
// }
|
|
||||||
// } else if (data.progress === -1) { // 任务执行失败
|
|
||||||
// ElNotification({
|
|
||||||
// title: '任务执行失败',
|
|
||||||
// message: "提示词:" + data['prompt'],
|
|
||||||
// type: 'error',
|
|
||||||
// })
|
|
||||||
// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
|
|
||||||
//
|
|
||||||
// } else {
|
|
||||||
// for (let i = 0; i < runningJobs.value.length; i++) {
|
|
||||||
// if (runningJobs.value[i].id === data.id) {
|
|
||||||
// isNew = false
|
|
||||||
// runningJobs.value[i] = data
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if (isNew) {
|
|
||||||
// runningJobs.value.push(data)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
//
|
|
||||||
// _socket.addEventListener('close', () => {
|
|
||||||
// ElMessage.error("Websocket 已经断开,正在重新连接服务器")
|
|
||||||
// connect()
|
|
||||||
// });
|
|
||||||
// }
|
|
||||||
|
|
||||||
const translatePrompt = () => {
|
const translatePrompt = () => {
|
||||||
loading.value = true
|
loading.value = true
|
||||||
httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
|
httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
|
||||||
params.value.prompt = res.data
|
params.value.prompt = res.data
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}).then(e => {
|
}).catch(e => {
|
||||||
ElMessage.error("翻译失败:" + e.message)
|
ElMessage.error("翻译失败:" + e.message)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -584,22 +544,10 @@ const translatePrompt = () => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
checkSession().then(user => {
|
checkSession().then(user => {
|
||||||
imgCalls.value = user['img_calls']
|
imgCalls.value = user['img_calls']
|
||||||
// 获取运行中的任务
|
|
||||||
httpGet(`/api/mj/jobs?status=0&user_id=${user['id']}`).then(res => {
|
|
||||||
runningJobs.value = res.data
|
|
||||||
}).catch(e => {
|
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 获取已完成的任务
|
fetchRunningJobs(user.id)
|
||||||
httpGet(`/api/mj/jobs?status=1&user_id=${user['id']}`).then(res => {
|
fetchFinishJobs(user.id)
|
||||||
finishedJobs.value = res.data
|
|
||||||
}).catch(e => {
|
|
||||||
ElMessage.error("获取任务失败:" + e.message)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 连接 socket
|
|
||||||
connect();
|
|
||||||
}).catch(() => {
|
}).catch(() => {
|
||||||
router.push('/login')
|
router.push('/login')
|
||||||
});
|
});
|
||||||
@ -614,6 +562,41 @@ onMounted(() => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 获取运行中的任务
|
||||||
|
const fetchRunningJobs = (userId) => {
|
||||||
|
httpGet(`/api/mj/jobs?status=0&user_id=${userId}`).then(res => {
|
||||||
|
const jobs = res.data
|
||||||
|
const _jobs = []
|
||||||
|
for (let i = 0; i < jobs.length; i++) {
|
||||||
|
if (jobs[i].progress === -1) {
|
||||||
|
ElNotification({
|
||||||
|
title: '任务执行失败',
|
||||||
|
message: "任务ID:" + jobs[i]['task_id'],
|
||||||
|
type: 'error',
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_jobs.push(jobs[i])
|
||||||
|
}
|
||||||
|
runningJobs.value = _jobs
|
||||||
|
|
||||||
|
setTimeout(() => fetchRunningJobs(userId), 10000)
|
||||||
|
|
||||||
|
}).catch(e => {
|
||||||
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetchFinishJobs = (userId) => {
|
||||||
|
// 获取已完成的任务
|
||||||
|
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
|
||||||
|
finishedJobs.value = res.data
|
||||||
|
setTimeout(() => fetchFinishJobs(userId), 10000)
|
||||||
|
}).catch(e => {
|
||||||
|
ElMessage.error("获取任务失败:" + e.message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// 切换图片比例
|
// 切换图片比例
|
||||||
const changeRate = (item) => {
|
const changeRate = (item) => {
|
||||||
params.value.rate = item.value
|
params.value.rate = item.value
|
||||||
@ -676,7 +659,6 @@ const variation = (index, item) => {
|
|||||||
const send = (url, index, item) => {
|
const send = (url, index, item) => {
|
||||||
httpPost(url, {
|
httpPost(url, {
|
||||||
index: index,
|
index: index,
|
||||||
src: "img",
|
|
||||||
message_id: item.message_id,
|
message_id: item.message_id,
|
||||||
message_hash: item.hash,
|
message_hash: item.hash,
|
||||||
session_id: getSessionId(),
|
session_id: getSessionId(),
|
||||||
|
@ -101,7 +101,7 @@ const login = function () {
|
|||||||
|
|
||||||
httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
|
httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
|
||||||
setUserToken(res.data)
|
setUserToken(res.data)
|
||||||
if (prevRoute.path === '') {
|
if (prevRoute.path === '' || prevRoute.path === '/register') {
|
||||||
if (isMobile()) {
|
if (isMobile()) {
|
||||||
router.push('/mobile')
|
router.push('/mobile')
|
||||||
} else {
|
} else {
|
||||||
|
@ -127,14 +127,6 @@ import {isMobile} from "@/utils/libs";
|
|||||||
import {setUserToken} from "@/store/session";
|
import {setUserToken} from "@/store/session";
|
||||||
import {checkSession} from "@/action/session";
|
import {checkSession} from "@/action/session";
|
||||||
|
|
||||||
checkSession().then(() => {
|
|
||||||
if (isMobile()) {
|
|
||||||
router.push('/mobile')
|
|
||||||
} else {
|
|
||||||
router.push('/chat')
|
|
||||||
}
|
|
||||||
}).catch(() => {
|
|
||||||
})
|
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const title = ref('ChatGPT-PLUS 用户注册');
|
const title = ref('ChatGPT-PLUS 用户注册');
|
||||||
const formData = ref({
|
const formData = ref({
|
||||||
|
Loading…
Reference in New Issue
Block a user