diff --git a/api/core/config.go b/api/core/config.go index 3ec3e76d..632a7b68 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -33,7 +33,6 @@ func NewDefaultConfig() *types.AppConfig { BasePath: "./static/upload", }, }, - MjConfigs: types.MidJourneyConfig{Enabled: false}, SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"}, WeChatBot: false, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 1a25fea3..f81df2a9 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -6,7 +6,6 @@ import ( "chatplus/core/types" "chatplus/handler" logger2 "chatplus/logger" - "chatplus/service/mj" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -32,16 +31,14 @@ var logger = logger2.GetLogger() type ChatHandler struct { handler.BaseHandler - db *gorm.DB - redis *redis.Client - mjService *mj.Service + db *gorm.DB + redis *redis.Client } -func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler { +func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler { h := ChatHandler{ - db: db, - redis: redis, - mjService: service, + db: db, + redis: redis, } h.App = app return &h diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 1acce9c0..2adbf7df 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,6 +3,7 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service" "chatplus/service/mj" "chatplus/store/model" "chatplus/store/vo" @@ -11,7 +12,6 @@ import ( "encoding/base64" "fmt" "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" "gorm.io/gorm" "strings" "time" @@ -19,26 +19,22 @@ import ( type MidJourneyHandler struct { BaseHandler - redis *redis.Client db *gorm.DB - mjService *mj.Service + pool *mj.ServicePool + snowflake *service.Snowflake } -func NewMidJourneyHandler( - app *core.AppServer, - client *redis.Client, - db *gorm.DB, - mjService *mj.Service) *MidJourneyHandler { +func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler { h := MidJourneyHandler{ - redis: client, db: db, - mjService: mjService, + snowflake: snowflake, + pool: pool, } h.App = app return &h } -func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { +func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { user, err := utils.GetLoginUser(c, h.db) if err != nil { resp.NotAuth(c) @@ -50,17 +46,17 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { return false } + if !h.pool.HasAvailableService() { + resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!") + return false + } + return true } // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { - if !h.App.Config.MjConfigs[0].Enabled { - resp.ERROR(c, "MidJourney service is disabled") - return - } - var data struct { SessionId string `json:"session_id"` Prompt string `json:"prompt"` @@ -80,7 +76,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - if !h.checkLimits(c) { + if !h.preCheck(c) { return } @@ -121,9 +117,16 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + // generate task id + taskId, err := h.snowflake.Next(true) + if err != nil { + resp.ERROR(c, "error with generate task id: "+err.Error()) + return + } job := model.MidJourneyJob{ Type: types.TaskImage.String(), UserId: userId, + TaskId: taskId, Progress: 0, Prompt: prompt, CreatedAt: time.Now(), @@ -133,24 +136,13 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } - h.mjService.PushTask(types.MjTask{ + h.pool.PushTask(types.MjTask{ Id: int(job.Id), SessionId: data.SessionId, - Src: types.TaskSrcImg, Type: types.TaskImage, - Prompt: prompt, + Prompt: fmt.Sprintf("%s %s", taskId, prompt), UserId: userId, }) - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - // 推送任务到前端 - client := h.mjService.Clients.Get(data.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } resp.SUCCESS(c) } @@ -174,65 +166,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - if !h.checkLimits(c) { + if !h.preCheck(c) { return } idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) - src := types.TaskSrc(data.Src) - if src == types.TaskSrcImg { - job := model.MidJourneyJob{ - Type: types.TaskUpscale.String(), - UserId: userId, - Hash: data.MessageHash, - Progress: 0, - Prompt: data.Prompt, - CreatedAt: time.Now(), - } - if res := h.db.Create(&job); res.Error == nil { - jobId = int(job.Id) - } else { - resp.ERROR(c, "添加任务失败:"+res.Error.Error()) - return - } - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - // 推送任务到前端 - client := h.mjService.Clients.Get(data.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } - } - h.mjService.PushTask(types.MjTask{ + h.pool.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, - Src: src, Type: types.TaskUpscale, Prompt: data.Prompt, UserId: userId, - RoleId: data.RoleId, - Icon: data.Icon, - ChatId: data.ChatId, Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, }) - - if src == types.TaskSrcChat { - wsClient := h.App.ChatClients.Get(data.SessionId) - if wsClient != nil { - content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) - utils.ReplyMessage(wsClient, content) - if h.mjService.ChatClients.Get(data.SessionId) == nil { - h.mjService.ChatClients.Put(data.SessionId, wsClient) - } - } - } resp.SUCCESS(c) } @@ -244,67 +194,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - if !h.checkLimits(c) { + if !h.preCheck(c) { return } idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) - src := types.TaskSrc(data.Src) - if src == types.TaskSrcImg { - job := model.MidJourneyJob{ - Type: types.TaskVariation.String(), - UserId: userId, - ImgURL: "", - Hash: data.MessageHash, - Progress: 0, - Prompt: data.Prompt, - CreatedAt: time.Now(), - } - if res := h.db.Create(&job); res.Error == nil { - jobId = int(job.Id) - } else { - resp.ERROR(c, "添加任务失败:"+res.Error.Error()) - return - } - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - // 推送任务到前端 - client := h.mjService.Clients.Get(data.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } - } - h.mjService.PushTask(types.MjTask{ + h.pool.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, - Src: src, Type: types.TaskVariation, Prompt: data.Prompt, UserId: userId, - RoleId: data.RoleId, - Icon: data.Icon, - ChatId: data.ChatId, Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, }) - - if src == types.TaskSrcChat { - // 从聊天窗口发送的请求,记录客户端信息 - wsClient := h.mjService.ChatClients.Get(data.SessionId) - if wsClient != nil { - content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) - utils.ReplyMessage(wsClient, content) - if h.mjService.Clients.Get(data.SessionId) == nil { - h.mjService.Clients.Put(data.SessionId, wsClient) - } - } - } resp.SUCCESS(c) } @@ -343,19 +249,27 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { if err != nil { continue } + + if job.Progress == -1 { + h.db.Delete(&model.MidJourneyJob{Id: job.Id}) + } + if item.Progress < 100 { // 10 分钟还没完成的任务直接删除 if time.Now().Sub(item.CreatedAt) > time.Minute*10 { h.db.Delete(&item) continue } - if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 - image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) + + // 正在运行中任务使用代理访问图片 + if item.ImgURL == "" && item.OrgURL != "" { + image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) if err == nil { job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) } } } + jobs = append(jobs, job) } resp.SUCCESS(c, jobs) diff --git a/api/handler/payment_handler.go b/api/handler/payment_handler.go index a75e2ccd..1b819024 100644 --- a/api/handler/payment_handler.go +++ b/api/handler/payment_handler.go @@ -176,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) { return } - orderNo, err := h.snowflake.Next() + orderNo, err := h.snowflake.Next(false) if err != nil { resp.ERROR(c, "error with generate trade no: "+err.Error()) return diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go index 8d0b1b5c..1a0eb36c 100644 --- a/api/handler/prompt_handler.go +++ b/api/handler/prompt_handler.go @@ -13,7 +13,8 @@ import ( "gorm.io/gorm" ) -const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" +const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" +const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" type PromptHandler struct { BaseHandler @@ -47,6 +48,25 @@ type apiErrRes struct { } `json:"error"` } +// Rewrite translate and rewrite prompt with ChatGPT +func (h *PromptHandler) Rewrite(c *gin.Context) { + var data struct { + Prompt string `json:"prompt"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + content, err := h.request(data.Prompt, rewritePromptTemplate) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, content) +} + func (h *PromptHandler) Translate(c *gin.Context) { var data struct { Prompt string `json:"prompt"` @@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + + content, err := h.request(data.Prompt, translatePromptTemplate) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, content) +} + +func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) { // 获取 OpenAI 的 API KEY var apiKey model.ApiKey res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey) if res.Error != nil { - resp.ERROR(c, "找不到可用 OpenAI API KEY") - return + return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) } messages := make([]interface{}, 1) messages[0] = types.Message{ Role: "user", - Content: fmt.Sprintf(translatePromptTemplate, data.Prompt), + Content: fmt.Sprintf(promptTemplate, prompt), } var response apiRes @@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) { SetErrorResult(&errRes). SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL) if err != nil || r.IsErrorState() { - resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)) - return + return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) } - resp.SUCCESS(c, response.Choices[0].Message.Content) + return response.Choices[0].Message.Content, nil } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index c98eeb05..83768252 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -141,7 +141,6 @@ func (h *SdJobHandler) Image(c *gin.Context) { h.service.PushTask(types.SdTask{ Id: int(job.Id), SessionId: data.SessionId, - Src: types.TaskSrcImg, Type: types.TaskImage, Prompt: data.Prompt, Params: params, diff --git a/api/main.go b/api/main.go index f8642af5..46626b13 100644 --- a/api/main.go +++ b/api/main.go @@ -163,8 +163,9 @@ func main() { } }), - // MidJourney 机器人 - fx.Provide(mj.NewBot), + // MidJourney service pool + fx.Provide(mj.NewServicePool), + // Stable Diffusion 机器人 fx.Provide(sd.NewService), fx.Invoke(func(config *types.AppConfig, service *sd.Service) { @@ -341,6 +342,7 @@ func main() { fx.Provide(handler.NewPromptHandler), fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { group := s.Engine.Group("/api/prompt/") + group.POST("rewrite", h.Rewrite) group.POST("translate", h.Translate) }), diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go index 1daa7f60..e4412b83 100644 --- a/api/service/mj/bot.go +++ b/api/service/mj/bot.go @@ -19,17 +19,18 @@ var logger = logger2.GetLogger() type Bot struct { config *types.MidJourneyConfig bot *discordgo.Session + name string service *Service } -func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { - discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken) +func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) { + discord, err := discordgo.New("Bot " + config.BotToken) if err != nil { return nil, err } - if config.ProxyURL != "" { - proxy, _ := url.Parse(config.ProxyURL) + if proxy != "" { + proxy, _ := url.Parse(proxy) discord.Client = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxy), @@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { } return &Bot{ - config: &config.MjConfigs, + config: config, bot: discord, + name: name, service: service, }, nil } @@ -52,13 +54,13 @@ func (b *Bot) Run() error { b.bot.AddHandler(b.messageCreate) b.bot.AddHandler(b.messageUpdate) - logger.Info("Starting MidJourney Bot...") + logger.Infof("Starting MidJourney %s", b.name) err := b.bot.Open() if err != nil { - logger.Error("Error opening Discord connection:", err) + logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err) return err } - logger.Info("Starting MidJourney Bot successfully!") + logger.Infof("Starting MidJourney %s successfully!", b.name) return nil } diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 5d20df80..584e93bd 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -2,7 +2,6 @@ package mj import ( "chatplus/core/types" - "chatplus/utils" "fmt" "github.com/imroc/req/v3" "time" @@ -18,9 +17,10 @@ type Client struct { func NewClient(config *types.MidJourneyConfig, proxy string) *Client { client := req.C().SetTimeout(10 * time.Second) // set proxy URL - if utils.IsEmptyValue(proxy) { + if proxy != "" { client.SetProxyURL(proxy) } + logger.Info(proxy) return &Client{client: client, config: config} } diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index 1711a4d8..ad4e0c9d 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -4,35 +4,63 @@ import ( "chatplus/core/types" "chatplus/service/oss" "chatplus/store" + "fmt" "github.com/go-redis/redis/v8" "gorm.io/gorm" ) // ServicePool Mj service pool type ServicePool struct { - services []Service + services []*Service taskQueue *store.RedisQueue } func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { + services := make([]*Service, 0) + queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) // create mj client and service - for _, config := range appConfig.MjConfigs { + for k, config := range appConfig.MjConfigs { if config.Enabled == false { continue } // create mj client client := NewClient(&config, appConfig.ProxyURL) + name := fmt.Sprintf("MjService-%d", k) // create mj service - service := NewService() + service := NewService(name, queue, 4, 600, db, client, manager, appConfig) + botName := fmt.Sprintf("MjBot-%d", k) + bot, err := NewBot(botName, appConfig.ProxyURL, &config, service) + if err != nil { + continue + } + + err = bot.Run() + if err != nil { + continue + } + + // run mj service + go func() { + service.Run() + }() + + services = append(services, service) } return &ServicePool{ - taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), + taskQueue: queue, + services: services, } } +// PushTask push a new mj task in to task queue func (p *ServicePool) PushTask(task types.MjTask) { logger.Debugf("add a new MidJourney task to the task list: %+v", task) p.taskQueue.RPush(task) } + +// HasAvailableService check if has available mj service in pool +func (p *ServicePool) HasAvailableService() bool { + return len(p.services) > 0 +} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 6b151509..e991ef01 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -27,16 +27,17 @@ type Service struct { snowflake *service.Snowflake } -func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { +func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { return &Service{ - name: name, - db: db, - taskQueue: queue, - client: client, - uploadManager: manager, - taskTimeout: timeout, - proxyURL: config.ProxyURL, - taskStartTimes: make(map[int]time.Time, 0), + name: name, + db: db, + taskQueue: queue, + client: client, + uploadManager: manager, + taskTimeout: timeout, + maxHandleTaskNum: maxTaskNum, + proxyURL: config.ProxyURL, + taskStartTimes: make(map[int]time.Time, 0), } } @@ -58,7 +59,7 @@ func (s *Service) Run() { continue } - logger.Infof("handle a new MidJourney task: %+v", task) + logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) switch task.Type { case types.TaskImage: err = s.client.Imagine(task.Prompt) @@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool { return handledNum < s.maxHandleTaskNum } +// remove the expired tasks func (s *Service) checkTasks() { for k, t := range s.taskStartTimes { if time.Now().Unix()-t.Unix() > s.taskTimeout { delete(s.taskStartTimes, k) atomic.AddInt32(&s.handledTaskNum, -1) + // delete task from database + s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") } } } @@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) { job.Progress = data.Progress job.Prompt = data.Prompt job.Hash = data.Image.Hash - job.OrgURL = data.Image.URL // save origin image + job.OrgURL = data.Image.URL // upload image - imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) - if err != nil { - logger.Error("error with download img: ", err.Error()) - return + if data.Status == Finished { + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) + if err != nil { + logger.Error("error with download img: ", err.Error()) + return + } + job.ImgURL = imgURL } - job.ImgURL = imgURL res = s.db.Updates(&job) if res.Error != nil { diff --git a/api/service/snowflake.go b/api/service/snowflake.go index 66416bef..a4c19193 100644 --- a/api/service/snowflake.go +++ b/api/service/snowflake.go @@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake { } // Next 生成一个新的唯一ID -func (s *Snowflake) Next() (string, error) { +func (s *Snowflake) Next(raw bool) (string, error) { s.mu.Lock() defer s.mu.Unlock() @@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) { s.lastTimestamp = timestamp id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence) + if raw { + return fmt.Sprintf("%d", id), nil + } now := time.Now() return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil } diff --git a/api/store/vo/mj_job.go b/api/store/vo/mj_job.go index bdb732fd..cbb9bbc6 100644 --- a/api/store/vo/mj_job.go +++ b/api/store/vo/mj_job.go @@ -6,6 +6,7 @@ type MidJourneyJob struct { Id uint `json:"id"` Type string `json:"type"` UserId int `json:"user_id"` + TaskId string `json:"task_id"` MessageId string `json:"message_id"` ReferenceId string `json:"reference_id"` ImgURL string `json:"img_url"` diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 17e11826..0bdd9afa 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -244,12 +244,30 @@ - - - - - 翻译 - +
+ + + + + 翻译 + + + + + + + + 翻译并重写 + + + +
@@ -432,8 +450,7 @@ import ItemList from "@/components/ItemList.vue"; import Clipboard from "clipboard"; import {checkSession} from "@/action/session"; import {useRouter} from "vue-router"; -import {getSessionId, getUserToken} from "@/store/session"; -import {removeArrayItem} from "@/utils/libs"; +import {getSessionId} from "@/store/session"; const listBoxHeight = ref(window.innerHeight - 40) const mjBoxHeight = ref(window.innerHeight - 150) @@ -504,79 +521,22 @@ const socket = ref(null) const imgCalls = ref(0) const loading = ref(false) -// const connect = () => { -// let host = process.env.VUE_APP_WS_HOST -// if (host === '') { -// if (location.protocol === 'https:') { -// host = 'wss://' + location.host; -// } else { -// host = 'ws://' + location.host; -// } -// } -// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`); -// _socket.addEventListener('open', () => { -// socket.value = _socket; -// }); -// -// _socket.addEventListener('message', event => { -// if (event.data instanceof Blob) { -// const reader = new FileReader(); -// reader.readAsText(event.data, "UTF-8"); -// reader.onload = () => { -// const data = JSON.parse(String(reader.result)); -// let isNew = true -// if (data.progress === 100) { -// for (let i = 0; i < finishedJobs.value.length; i++) { -// if (finishedJobs.value[i].id === data.id) { -// isNew = false -// break -// } -// } -// for (let i = 0; i < runningJobs.value.length; i++) { -// if (runningJobs.value[i].id === data.id) { -// runningJobs.value.splice(i, 1) -// break -// } -// } -// if (isNew) { -// finishedJobs.value.unshift(data) -// } -// } else if (data.progress === -1) { // 任务执行失败 -// ElNotification({ -// title: '任务执行失败', -// message: "提示词:" + data['prompt'], -// type: 'error', -// }) -// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id) -// -// } else { -// for (let i = 0; i < runningJobs.value.length; i++) { -// if (runningJobs.value[i].id === data.id) { -// isNew = false -// runningJobs.value[i] = data -// break -// } -// } -// if (isNew) { -// runningJobs.value.push(data) -// } -// } -// } -// } -// }); -// -// _socket.addEventListener('close', () => { -// ElMessage.error("Websocket 已经断开,正在重新连接服务器") -// connect() -// }); -// } +const rewritePrompt = () => { + loading.value = true + httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => { + params.value.prompt = res.data + loading.value = false + }).catch(e => { + ElMessage.error("翻译失败:" + e.message) + }) +} const translatePrompt = () => { loading.value = true httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => { params.value.prompt = res.data loading.value = false - }).then(e => { + }).catch(e => { ElMessage.error("翻译失败:" + e.message) }) } @@ -584,22 +544,10 @@ const translatePrompt = () => { onMounted(() => { checkSession().then(user => { imgCalls.value = user['img_calls'] - // 获取运行中的任务 - httpGet(`/api/mj/jobs?status=0&user_id=${user['id']}`).then(res => { - runningJobs.value = res.data - }).catch(e => { - ElMessage.error("获取任务失败:" + e.message) - }) - // 获取已完成的任务 - httpGet(`/api/mj/jobs?status=1&user_id=${user['id']}`).then(res => { - finishedJobs.value = res.data - }).catch(e => { - ElMessage.error("获取任务失败:" + e.message) - }) + fetchRunningJobs(user.id) + fetchFinishJobs(user.id) - // 连接 socket - connect(); }).catch(() => { router.push('/login') }); @@ -614,6 +562,41 @@ onMounted(() => { }) }) +// 获取运行中的任务 +const fetchRunningJobs = (userId) => { + httpGet(`/api/mj/jobs?status=0&user_id=${userId}`).then(res => { + const jobs = res.data + const _jobs = [] + for (let i = 0; i < jobs.length; i++) { + if (jobs[i].progress === -1) { + ElNotification({ + title: '任务执行失败', + message: "任务ID:" + jobs[i]['task_id'], + type: 'error', + }) + continue + } + _jobs.push(jobs[i]) + } + runningJobs.value = _jobs + + setTimeout(() => fetchRunningJobs(userId), 10000) + + }).catch(e => { + ElMessage.error("获取任务失败:" + e.message) + }) +} + +const fetchFinishJobs = (userId) => { + // 获取已完成的任务 + httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => { + finishedJobs.value = res.data + setTimeout(() => fetchFinishJobs(userId), 10000) + }).catch(e => { + ElMessage.error("获取任务失败:" + e.message) + }) +} + // 切换图片比例 const changeRate = (item) => { params.value.rate = item.value @@ -676,7 +659,6 @@ const variation = (index, item) => { const send = (url, index, item) => { httpPost(url, { index: index, - src: "img", message_id: item.message_id, message_hash: item.hash, session_id: getSessionId(), diff --git a/web/src/views/Login.vue b/web/src/views/Login.vue index 3ac9a46f..a405f5aa 100644 --- a/web/src/views/Login.vue +++ b/web/src/views/Login.vue @@ -101,7 +101,7 @@ const login = function () { httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => { setUserToken(res.data) - if (prevRoute.path === '') { + if (prevRoute.path === '' || prevRoute.path === '/register') { if (isMobile()) { router.push('/mobile') } else { diff --git a/web/src/views/Register.vue b/web/src/views/Register.vue index 01c07f7e..2ed3d744 100644 --- a/web/src/views/Register.vue +++ b/web/src/views/Register.vue @@ -127,14 +127,6 @@ import {isMobile} from "@/utils/libs"; import {setUserToken} from "@/store/session"; import {checkSession} from "@/action/session"; -checkSession().then(() => { - if (isMobile()) { - router.push('/mobile') - } else { - router.push('/chat') - } -}).catch(() => { -}) const router = useRouter(); const title = ref('ChatGPT-PLUS 用户注册'); const formData = ref({