refactor: add midjourney pool implementation, add translate prompt for mj drawing

This commit is contained in:
RockYang 2023-12-13 16:38:27 +08:00
parent 96816c12ca
commit a398e7a550
16 changed files with 226 additions and 272 deletions

View File

@ -33,7 +33,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
MjConfigs: types.MidJourneyConfig{Enabled: false},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"}, SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false, WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},

View File

@ -6,7 +6,6 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/handler" "chatplus/handler"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/service/mj"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
@ -32,16 +31,14 @@ var logger = logger2.GetLogger()
type ChatHandler struct { type ChatHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB db *gorm.DB
redis *redis.Client redis *redis.Client
mjService *mj.Service
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
h := ChatHandler{ h := ChatHandler{
db: db, db: db,
redis: redis, redis: redis,
mjService: service,
} }
h.App = app h.App = app
return &h return &h

View File

@ -3,6 +3,7 @@ package handler
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/service"
"chatplus/service/mj" "chatplus/service/mj"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
@ -11,7 +12,6 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
"strings" "strings"
"time" "time"
@ -19,26 +19,22 @@ import (
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
redis *redis.Client
db *gorm.DB db *gorm.DB
mjService *mj.Service pool *mj.ServicePool
snowflake *service.Snowflake
} }
func NewMidJourneyHandler( func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler {
app *core.AppServer,
client *redis.Client,
db *gorm.DB,
mjService *mj.Service) *MidJourneyHandler {
h := MidJourneyHandler{ h := MidJourneyHandler{
redis: client,
db: db, db: db,
mjService: mjService, snowflake: snowflake,
pool: pool,
} }
h.App = app h.App = app
return &h return &h
} }
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := utils.GetLoginUser(c, h.db)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
@ -50,17 +46,17 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true return true
} }
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
if !h.App.Config.MjConfigs[0].Enabled {
resp.ERROR(c, "MidJourney service is disabled")
return
}
var data struct { var data struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
@ -80,7 +76,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
@ -121,9 +117,16 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
// generate task id
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: types.TaskImage.String(), Type: types.TaskImage.String(),
UserId: userId, UserId: userId,
TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: prompt, Prompt: prompt,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -133,24 +136,13 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
h.mjService.PushTask(types.MjTask{ h.pool.PushTask(types.MjTask{
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage, Type: types.TaskImage,
Prompt: prompt, Prompt: fmt.Sprintf("%s %s", taskId, prompt),
UserId: userId, UserId: userId,
}) })
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -174,65 +166,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src) h.pool.PushTask(types.MjTask{
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
UserId: userId,
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
Src: src,
Type: types.TaskUpscale, Type: types.TaskUpscale,
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
if src == types.TaskSrcChat {
wsClient := h.App.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.ChatClients.Get(data.SessionId) == nil {
h.mjService.ChatClients.Put(data.SessionId, wsClient)
}
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -244,67 +194,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src) h.pool.PushTask(types.MjTask{
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
UserId: userId,
ImgURL: "",
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
Src: src,
Type: types.TaskVariation, Type: types.TaskVariation,
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
if src == types.TaskSrcChat {
// 从聊天窗口发送的请求,记录客户端信息
wsClient := h.mjService.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.Clients.Get(data.SessionId) == nil {
h.mjService.Clients.Put(data.SessionId, wsClient)
}
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@ -343,19 +249,27 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if err != nil { if err != nil {
continue continue
} }
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 { if item.Progress < 100 {
// 10 分钟还没完成的任务直接删除 // 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 { if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item) h.db.Delete(&item)
continue continue
} }
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) // 正在运行中任务使用代理访问图片
if item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil { if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
} }
} }
} }
jobs = append(jobs, job) jobs = append(jobs, job)
} }
resp.SUCCESS(c, jobs) resp.SUCCESS(c, jobs)

View File

@ -176,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return return
} }
orderNo, err := h.snowflake.Next() orderNo, err := h.snowflake.Next(false)
if err != nil { if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error()) resp.ERROR(c, "error with generate trade no: "+err.Error())
return return

View File

@ -13,7 +13,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct { type PromptHandler struct {
BaseHandler BaseHandler
@ -47,6 +48,25 @@ type apiErrRes struct {
} `json:"error"` } `json:"error"`
} }
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, rewritePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) { func (h *PromptHandler) Translate(c *gin.Context) {
var data struct { var data struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := h.request(data.Prompt, translatePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY // 获取 OpenAI 的 API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey) res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "找不到可用 OpenAI API KEY") return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
return
} }
messages := make([]interface{}, 1) messages := make([]interface{}, 1)
messages[0] = types.Message{ messages[0] = types.Message{
Role: "user", Role: "user",
Content: fmt.Sprintf(translatePromptTemplate, data.Prompt), Content: fmt.Sprintf(promptTemplate, prompt),
} }
var response apiRes var response apiRes
@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) {
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL) SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)) return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
return
} }
resp.SUCCESS(c, response.Choices[0].Message.Content) return response.Choices[0].Message.Content, nil
} }

View File

@ -141,7 +141,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
h.service.PushTask(types.SdTask{ h.service.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage, Type: types.TaskImage,
Prompt: data.Prompt, Prompt: data.Prompt,
Params: params, Params: params,

View File

@ -163,8 +163,9 @@ func main() {
} }
}), }),
// MidJourney 机器人 // MidJourney service pool
fx.Provide(mj.NewBot), fx.Provide(mj.NewServicePool),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewService), fx.Provide(sd.NewService),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) { fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
@ -341,6 +342,7 @@ func main() {
fx.Provide(handler.NewPromptHandler), fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt/") group := s.Engine.Group("/api/prompt/")
group.POST("rewrite", h.Rewrite)
group.POST("translate", h.Translate) group.POST("translate", h.Translate)
}), }),

View File

@ -19,17 +19,18 @@ var logger = logger2.GetLogger()
type Bot struct { type Bot struct {
config *types.MidJourneyConfig config *types.MidJourneyConfig
bot *discordgo.Session bot *discordgo.Session
name string
service *Service service *Service
} }
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken) discord, err := discordgo.New("Bot " + config.BotToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.ProxyURL != "" { if proxy != "" {
proxy, _ := url.Parse(config.ProxyURL) proxy, _ := url.Parse(proxy)
discord.Client = &http.Client{ discord.Client = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyURL(proxy), Proxy: http.ProxyURL(proxy),
@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
} }
return &Bot{ return &Bot{
config: &config.MjConfigs, config: config,
bot: discord, bot: discord,
name: name,
service: service, service: service,
}, nil }, nil
} }
@ -52,13 +54,13 @@ func (b *Bot) Run() error {
b.bot.AddHandler(b.messageCreate) b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate) b.bot.AddHandler(b.messageUpdate)
logger.Info("Starting MidJourney Bot...") logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open() err := b.bot.Open()
if err != nil { if err != nil {
logger.Error("Error opening Discord connection:", err) logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err return err
} }
logger.Info("Starting MidJourney Bot successfully!") logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil return nil
} }

View File

@ -2,7 +2,6 @@ package mj
import ( import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/utils"
"fmt" "fmt"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"time" "time"
@ -18,9 +17,10 @@ type Client struct {
func NewClient(config *types.MidJourneyConfig, proxy string) *Client { func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second) client := req.C().SetTimeout(10 * time.Second)
// set proxy URL // set proxy URL
if utils.IsEmptyValue(proxy) { if proxy != "" {
client.SetProxyURL(proxy) client.SetProxyURL(proxy)
} }
logger.Info(proxy)
return &Client{client: client, config: config} return &Client{client: client, config: config}
} }

View File

@ -4,35 +4,63 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"fmt"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
) )
// ServicePool Mj service pool // ServicePool Mj service pool
type ServicePool struct { type ServicePool struct {
services []Service services []*Service
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
} }
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
// create mj client and service // create mj client and service
for _, config := range appConfig.MjConfigs { for k, config := range appConfig.MjConfigs {
if config.Enabled == false { if config.Enabled == false {
continue continue
} }
// create mj client // create mj client
client := NewClient(&config, appConfig.ProxyURL) client := NewClient(&config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service // create mj service
service := NewService() service := NewService(name, queue, 4, 600, db, client, manager, appConfig)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
} }
return &ServicePool{ return &ServicePool{
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), taskQueue: queue,
services: services,
} }
} }
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) { func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task) logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task) p.taskQueue.RPush(task)
} }
// HasAvailableService check if has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@ -27,16 +27,17 @@ type Service struct {
snowflake *service.Snowflake snowflake *service.Snowflake
} }
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{ return &Service{
name: name, name: name,
db: db, db: db,
taskQueue: queue, taskQueue: queue,
client: client, client: client,
uploadManager: manager, uploadManager: manager,
taskTimeout: timeout, taskTimeout: timeout,
proxyURL: config.ProxyURL, maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0), proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
} }
} }
@ -58,7 +59,7 @@ func (s *Service) Run() {
continue continue
} }
logger.Infof("handle a new MidJourney task: %+v", task) logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
err = s.client.Imagine(task.Prompt) err = s.client.Imagine(task.Prompt)
@ -92,11 +93,14 @@ func (s *Service) canHandleTask() bool {
return handledNum < s.maxHandleTaskNum return handledNum < s.maxHandleTaskNum
} }
// remove the expired tasks
func (s *Service) checkTasks() { func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes { for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout { if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k) delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1) atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
} }
} }
} }
@ -121,15 +125,17 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress job.Progress = data.Progress
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.Hash = data.Image.Hash job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL // save origin image job.OrgURL = data.Image.URL
// upload image // upload image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) if data.Status == Finished {
if err != nil { imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
logger.Error("error with download img: ", err.Error()) if err != nil {
return logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
} }
job.ImgURL = imgURL
res = s.db.Updates(&job) res = s.db.Updates(&job)
if res.Error != nil { if res.Error != nil {

View File

@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
} }
// Next 生成一个新的唯一ID // Next 生成一个新的唯一ID
func (s *Snowflake) Next() (string, error) { func (s *Snowflake) Next(raw bool) (string, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
s.lastTimestamp = timestamp s.lastTimestamp = timestamp
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence) id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
if raw {
return fmt.Sprintf("%d", id), nil
}
now := time.Now() now := time.Now()
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
} }

View File

@ -6,6 +6,7 @@ type MidJourneyJob struct {
Id uint `json:"id"` Id uint `json:"id"`
Type string `json:"type"` Type string `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
TaskId string `json:"task_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"` ReferenceId string `json:"reference_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`

View File

@ -244,12 +244,30 @@
</el-icon> </el-icon>
</el-tooltip> </el-tooltip>
</div> </div>
<el-button type="success" @click="translatePrompt"> <div>
<el-icon style="margin-right: 6px;font-size: 18px;"> <el-button type="primary" @click="translatePrompt">
<Refresh/> <el-icon style="margin-right: 6px;font-size: 18px;">
</el-icon> <Refresh/>
翻译 </el-icon>
</el-button> 翻译
</el-button>
<el-tooltip
class="box-item"
effect="light"
raw-content
content="使用 AI 翻译并重写提示词,<br/>增加更多细节,风格等描述"
placement="top-end"
>
<el-button type="success" @click="rewritePrompt">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译并重写
</el-button>
</el-tooltip>
</div>
</div> </div>
</div> </div>
@ -432,8 +450,7 @@ import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId, getUserToken} from "@/store/session"; import {getSessionId} from "@/store/session";
import {removeArrayItem} from "@/utils/libs";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const mjBoxHeight = ref(window.innerHeight - 150)
@ -504,79 +521,22 @@ const socket = ref(null)
const imgCalls = ref(0) const imgCalls = ref(0)
const loading = ref(false) const loading = ref(false)
// const connect = () => { const rewritePrompt = () => {
// let host = process.env.VUE_APP_WS_HOST loading.value = true
// if (host === '') { httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => {
// if (location.protocol === 'https:') { params.value.prompt = res.data
// host = 'wss://' + location.host; loading.value = false
// } else { }).catch(e => {
// host = 'ws://' + location.host; ElMessage.error("翻译失败:" + e.message)
// } })
// } }
// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
// _socket.addEventListener('open', () => {
// socket.value = _socket;
// });
//
// _socket.addEventListener('message', event => {
// if (event.data instanceof Blob) {
// const reader = new FileReader();
// reader.readAsText(event.data, "UTF-8");
// reader.onload = () => {
// const data = JSON.parse(String(reader.result));
// let isNew = true
// if (data.progress === 100) {
// for (let i = 0; i < finishedJobs.value.length; i++) {
// if (finishedJobs.value[i].id === data.id) {
// isNew = false
// break
// }
// }
// for (let i = 0; i < runningJobs.value.length; i++) {
// if (runningJobs.value[i].id === data.id) {
// runningJobs.value.splice(i, 1)
// break
// }
// }
// if (isNew) {
// finishedJobs.value.unshift(data)
// }
// } else if (data.progress === -1) { //
// ElNotification({
// title: '',
// message: "" + data['prompt'],
// type: 'error',
// })
// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
//
// } else {
// for (let i = 0; i < runningJobs.value.length; i++) {
// if (runningJobs.value[i].id === data.id) {
// isNew = false
// runningJobs.value[i] = data
// break
// }
// }
// if (isNew) {
// runningJobs.value.push(data)
// }
// }
// }
// }
// });
//
// _socket.addEventListener('close', () => {
// ElMessage.error("Websocket ")
// connect()
// });
// }
const translatePrompt = () => { const translatePrompt = () => {
loading.value = true loading.value = true
httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => { httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
params.value.prompt = res.data params.value.prompt = res.data
loading.value = false loading.value = false
}).then(e => { }).catch(e => {
ElMessage.error("翻译失败:" + e.message) ElMessage.error("翻译失败:" + e.message)
}) })
} }
@ -584,22 +544,10 @@ const translatePrompt = () => {
onMounted(() => { onMounted(() => {
checkSession().then(user => { checkSession().then(user => {
imgCalls.value = user['img_calls'] imgCalls.value = user['img_calls']
//
httpGet(`/api/mj/jobs?status=0&user_id=${user['id']}`).then(res => {
runningJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// fetchRunningJobs(user.id)
httpGet(`/api/mj/jobs?status=1&user_id=${user['id']}`).then(res => { fetchFinishJobs(user.id)
finishedJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// socket
connect();
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
}); });
@ -614,6 +562,41 @@ onMounted(() => {
}) })
}) })
//
const fetchRunningJobs = (userId) => {
httpGet(`/api/mj/jobs?status=0&user_id=${userId}`).then(res => {
const jobs = res.data
const _jobs = []
for (let i = 0; i < jobs.length; i++) {
if (jobs[i].progress === -1) {
ElNotification({
title: '任务执行失败',
message: "任务ID" + jobs[i]['task_id'],
type: 'error',
})
continue
}
_jobs.push(jobs[i])
}
runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 10000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
const fetchFinishJobs = (userId) => {
//
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 10000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
// //
const changeRate = (item) => { const changeRate = (item) => {
params.value.rate = item.value params.value.rate = item.value
@ -676,7 +659,6 @@ const variation = (index, item) => {
const send = (url, index, item) => { const send = (url, index, item) => {
httpPost(url, { httpPost(url, {
index: index, index: index,
src: "img",
message_id: item.message_id, message_id: item.message_id,
message_hash: item.hash, message_hash: item.hash,
session_id: getSessionId(), session_id: getSessionId(),

View File

@ -101,7 +101,7 @@ const login = function () {
httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => { httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
setUserToken(res.data) setUserToken(res.data)
if (prevRoute.path === '') { if (prevRoute.path === '' || prevRoute.path === '/register') {
if (isMobile()) { if (isMobile()) {
router.push('/mobile') router.push('/mobile')
} else { } else {

View File

@ -127,14 +127,6 @@ import {isMobile} from "@/utils/libs";
import {setUserToken} from "@/store/session"; import {setUserToken} from "@/store/session";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
checkSession().then(() => {
if (isMobile()) {
router.push('/mobile')
} else {
router.push('/chat')
}
}).catch(() => {
})
const router = useRouter(); const router = useRouter();
const title = ref('ChatGPT-PLUS 用户注册'); const title = ref('ChatGPT-PLUS 用户注册');
const formData = ref({ const formData = ref({