Compare commits

..

5 Commits

Author SHA1 Message Date
RockYang
8f057ca9d1 refactor: refactor stable diffusion service, add service pool support 2023-12-14 16:48:54 +08:00
RockYang
4a56621ec3 chore: add sub dir support for OSS 2023-12-13 17:02:49 +08:00
RockYang
a398e7a550 refactor: add midjourney pool implementation, add translate prompt for mj drawing 2023-12-13 16:38:27 +08:00
RockYang
96816c12ca fix: fixed bug for aliyun OSS img url 2023-12-13 09:49:55 +08:00
RockYang
9984926f69 refactor mj service, add mj service pool support 2023-12-12 18:33:24 +08:00
31 changed files with 613 additions and 752 deletions

View File

@@ -169,9 +169,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
var tokenString string var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader) tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" || } else if c.Request.URL.Path == "/api/chat/new" {
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/sd/client" {
tokenString = c.Query("token") tokenString = c.Query("token")
} else { } else {
tokenString = c.GetHeader(types.UserAuthHeader) tokenString = c.GetHeader(types.UserAuthHeader)

View File

@@ -33,8 +33,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload", BasePath: "./static/upload",
}, },
}, },
MjConfig: types.MidJourneyConfig{Enabled: false},
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
WeChatBot: false, WeChatBot: false,
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
} }

View File

@@ -18,9 +18,9 @@ type AppConfig struct {
AesEncryptKey string AesEncryptKey string
SmsConfig AliYunSmsConfig // AliYun send message service config SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjConfig MidJourneyConfig // mj 绘画配置 MjConfigs []MidJourneyConfig // mj AI draw service pool
WeChatBot bool // 是否启用微信机器人 WeChatBot bool // 是否启用微信机器人
SdConfig StableDiffusionConfig // sd 绘画配置 SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig XXLConfig XXLConfig
AlipayConfig AlipayConfig AlipayConfig AlipayConfig

View File

@@ -12,6 +12,7 @@ type MiniOssConfig struct {
AccessKey string AccessKey string
AccessSecret string AccessSecret string
Bucket string Bucket string
SubDir string
UseSSL bool UseSSL bool
Domain string Domain string
} }
@@ -21,6 +22,7 @@ type QiNiuOssConfig struct {
AccessKey string AccessKey string
AccessSecret string AccessSecret string
Bucket string Bucket string
SubDir string
Domain string Domain string
} }
@@ -29,6 +31,7 @@ type AliYunOssConfig struct {
AccessKey string AccessKey string
AccessSecret string AccessSecret string
Bucket string Bucket string
SubDir string
Domain string Domain string
} }

View File

@@ -11,28 +11,15 @@ const (
TaskImage = TaskType("image") TaskImage = TaskType("image")
TaskUpscale = TaskType("upscale") TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation") TaskVariation = TaskType("variation")
TaskTxt2Img = TaskType("text2img")
)
// TaskSrc 任务来源
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
TaskSrcImg = TaskSrc("img") // 专业绘画页面
) )
// MjTask MidJourney 任务 // MjTask MidJourney 任务
type MjTask struct { type MjTask struct {
Id int `json:"id"` Id int `json:"id"`
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"` MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"` MessageHash string `json:"message_hash,omitempty"`
@@ -42,7 +29,6 @@ type MjTask struct {
type SdTask struct { type SdTask struct {
Id int `json:"id"` // job 数据库ID Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`

View File

@@ -6,7 +6,6 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/handler" "chatplus/handler"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/service/mj"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
@@ -34,14 +33,12 @@ type ChatHandler struct {
handler.BaseHandler handler.BaseHandler
db *gorm.DB db *gorm.DB
redis *redis.Client redis *redis.Client
mjService *mj.Service
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
h := ChatHandler{ h := ChatHandler{
db: db, db: db,
redis: redis, redis: redis,
mjService: service,
} }
h.App = app h.App = app
return &h return &h

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/service"
"chatplus/service/mj" "chatplus/service/mj"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
@@ -11,50 +12,29 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"strings" "strings"
"time" "time"
) )
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
redis *redis.Client
db *gorm.DB db *gorm.DB
mjService *mj.Service pool *mj.ServicePool
snowflake *service.Snowflake
} }
func NewMidJourneyHandler( func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler {
app *core.AppServer,
client *redis.Client,
db *gorm.DB,
mjService *mj.Service) *MidJourneyHandler {
h := MidJourneyHandler{ h := MidJourneyHandler{
redis: client,
db: db, db: db,
mjService: mjService, snowflake: snowflake,
pool: pool,
} }
h.App = app h.App = app
return &h return &h
} }
// Client WebSocket 客户端,用于通知任务状态变更 func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
h.mjService.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := utils.GetLoginUser(c, h.db)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
@@ -66,17 +46,17 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true return true
} }
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) { func (h *MidJourneyHandler) Image(c *gin.Context) {
if !h.App.Config.MjConfig.Enabled {
resp.ERROR(c, "MidJourney service is disabled")
return
}
var data struct { var data struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
@@ -96,7 +76,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
@@ -137,9 +117,16 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
// generate task id
taskId, err := h.snowflake.Next(true)
if err != nil {
resp.ERROR(c, "error with generate task id: "+err.Error())
return
}
job := model.MidJourneyJob{ job := model.MidJourneyJob{
Type: types.TaskImage.String(), Type: types.TaskImage.String(),
UserId: userId, UserId: userId,
TaskId: taskId,
Progress: 0, Progress: 0,
Prompt: prompt, Prompt: prompt,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@@ -149,24 +136,13 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
h.mjService.PushTask(types.MjTask{ h.pool.PushTask(types.MjTask{
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage, Type: types.TaskImage,
Prompt: prompt, Prompt: fmt.Sprintf("%s %s", taskId, prompt),
UserId: userId, UserId: userId,
}) })
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -190,65 +166,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src) h.pool.PushTask(types.MjTask{
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
UserId: userId,
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
Src: src,
Type: types.TaskUpscale, Type: types.TaskUpscale,
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
if src == types.TaskSrcChat {
wsClient := h.App.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.ChatClients.Get(data.SessionId) == nil {
h.mjService.ChatClients.Put(data.SessionId, wsClient)
}
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -260,67 +194,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return return
} }
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
idValue, _ := c.Get(types.LoginUserID) idValue, _ := c.Get(types.LoginUserID)
jobId := 0 jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0) userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := types.TaskSrc(data.Src) h.pool.PushTask(types.MjTask{
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
UserId: userId,
ImgURL: "",
Hash: data.MessageHash,
Progress: 0,
Prompt: data.Prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error == nil {
jobId = int(job.Id)
} else {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
}
h.mjService.PushTask(types.MjTask{
Id: jobId, Id: jobId,
SessionId: data.SessionId, SessionId: data.SessionId,
Src: src,
Type: types.TaskVariation, Type: types.TaskVariation,
Prompt: data.Prompt, Prompt: data.Prompt,
UserId: userId, UserId: userId,
RoleId: data.RoleId,
Icon: data.Icon,
ChatId: data.ChatId,
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
}) })
if src == types.TaskSrcChat {
// 从聊天窗口发送的请求,记录客户端信息
wsClient := h.mjService.ChatClients.Get(data.SessionId)
if wsClient != nil {
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
utils.ReplyMessage(wsClient, content)
if h.mjService.Clients.Get(data.SessionId) == nil {
h.mjService.Clients.Put(data.SessionId, wsClient)
}
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -359,19 +249,27 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
if err != nil { if err != nil {
continue continue
} }
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 { if item.Progress < 100 {
// 10 分钟还没完成的任务直接删除 // 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*10 { if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item) h.db.Delete(&item)
continue continue
} }
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) // 正在运行中任务使用代理访问图片
if item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil { if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
} }
} }
} }
jobs = append(jobs, job) jobs = append(jobs, job)
} }
resp.SUCCESS(c, jobs) resp.SUCCESS(c, jobs)

View File

@@ -176,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
return return
} }
orderNo, err := h.snowflake.Next() orderNo, err := h.snowflake.Next(false)
if err != nil { if err != nil {
resp.ERROR(c, "error with generate trade no: "+err.Error()) resp.ERROR(c, "error with generate trade no: "+err.Error())
return return

View File

@@ -13,7 +13,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
type PromptHandler struct { type PromptHandler struct {
BaseHandler BaseHandler
@@ -47,6 +48,25 @@ type apiErrRes struct {
} `json:"error"` } `json:"error"`
} }
// Rewrite translate and rewrite prompt with ChatGPT
func (h *PromptHandler) Rewrite(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := h.request(data.Prompt, rewritePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) Translate(c *gin.Context) { func (h *PromptHandler) Translate(c *gin.Context) {
var data struct { var data struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
@@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
content, err := h.request(data.Prompt, translatePromptTemplate)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
// 获取 OpenAI 的 API KEY // 获取 OpenAI 的 API KEY
var apiKey model.ApiKey var apiKey model.ApiKey
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey) res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
if res.Error != nil { if res.Error != nil {
resp.ERROR(c, "找不到可用 OpenAI API KEY") return "", fmt.Errorf("error with fetch OpenAI API KEY%v", res.Error)
return
} }
messages := make([]interface{}, 1) messages := make([]interface{}, 1)
messages[0] = types.Message{ messages[0] = types.Message{
Role: "user", Role: "user",
Content: fmt.Sprintf(translatePromptTemplate, data.Prompt), Content: fmt.Sprintf(promptTemplate, prompt),
} }
var response apiRes var response apiRes
@@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) {
SetErrorResult(&errRes). SetErrorResult(&errRes).
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL) SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
if err != nil || r.IsErrorState() { if err != nil || r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)) return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
return
} }
resp.SUCCESS(c, response.Choices[0].Message.Content) return response.Choices[0].Message.Content, nil
} }

View File

@@ -8,12 +8,11 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"encoding/base64"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
"time" "time"
) )
@@ -21,34 +20,18 @@ type SdJobHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
db *gorm.DB db *gorm.DB
service *sd.Service pool *sd.ServicePool
} }
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler { func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool) *SdJobHandler {
h := SdJobHandler{ h := SdJobHandler{
redis: redisCli,
db: db, db: db,
service: service, pool: pool,
} }
h.App = app h.App = app
return &h return &h
} }
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.service.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool { func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db) user, err := utils.GetLoginUser(c, h.db)
if err != nil { if err != nil {
@@ -56,6 +39,11 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
return false
}
if user.ImgCalls <= 0 { if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
return false return false
@@ -67,11 +55,6 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) { func (h *SdJobHandler) Image(c *gin.Context) {
if !h.App.Config.SdConfig.Enabled {
resp.ERROR(c, "Stable Diffusion service is disabled")
return
}
if !h.checkLimits(c) { if !h.checkLimits(c) {
return return
} }
@@ -129,7 +112,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params), Params: utils.JsonEncode(params),
Prompt: data.Prompt, Prompt: data.Prompt,
Progress: 0, Progress: 0,
Started: false,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
res := h.db.Create(&job) res := h.db.Create(&job)
@@ -138,24 +120,15 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return return
} }
h.service.PushTask(types.SdTask{ h.pool.PushTask(types.SdTask{
Id: int(job.Id), Id: int(job.Id),
SessionId: data.SessionId, SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage, Type: types.TaskImage,
Prompt: data.Prompt, Prompt: data.Prompt,
Params: params, Params: params,
UserId: userId, UserId: userId,
}) })
var jobVo vo.SdJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.service.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c) resp.SUCCESS(c)
} }
@@ -194,12 +167,22 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
if err != nil { if err != nil {
continue continue
} }
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
}
if item.Progress < 100 { if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除 // 10 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 { if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
h.db.Delete(&item) h.db.Delete(&item)
continue continue
} }
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
} }
jobs = append(jobs, job) jobs = append(jobs, job)
} }

View File

@@ -163,34 +163,11 @@ func main() {
} }
}), }),
// MidJourney 机器人 // MidJourney service pool
fx.Provide(mj.NewBot), fx.Provide(mj.NewServicePool),
fx.Provide(mj.NewClient),
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
if config.MjConfig.Enabled {
err := bot.Run()
if err != nil {
log.Fatal("MidJourney 服务启动失败:", err)
}
}
}),
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
if config.MjConfig.Enabled {
go func() {
mjService.Run()
}()
}
}),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
fx.Provide(sd.NewService), fx.Provide(sd.NewServicePool),
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
if config.SdConfig.Enabled {
go func() {
service.Run()
}()
}
}),
fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewHuPiPay),
@@ -256,13 +233,11 @@ func main() {
group.POST("upscale", h.Upscale) group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation) group.POST("variation", h.Variation)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd") group := s.Engine.Group("/api/sd")
group.POST("image", h.Image) group.POST("image", h.Image)
group.GET("jobs", h.JobList) group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}), }),
// 管理后台控制器 // 管理后台控制器
@@ -360,6 +335,7 @@ func main() {
fx.Provide(handler.NewPromptHandler), fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt/") group := s.Engine.Group("/api/prompt/")
group.POST("rewrite", h.Rewrite)
group.POST("translate", h.Translate) group.POST("translate", h.Translate)
}), }),

View File

@@ -1,21 +1,21 @@
{ {
"data": [ "data": [
"task(s95jqt5jr8yppcp)", "task(owy5niy1sbbnlq0)",
"A beautiful Chinese girl in a garden", "A beautiful Chinese girl plays the guitar on the beach. She is dressed in a flowing dress that matches the colors of the sunset. With her eyes closed, she strums the guitar with passion and confidence, her fingers dancing gracefully on the strings. The painting employs a vibrant color palette, capturing the warmth of the setting sun blending with the serene hues of the ocean. The artist uses a combination of impressionistic and realistic brushstrokes to convey both the girl's delicate features and the dynamic movement of the waves. The rendering effect creates a dream-like atmosphere, as if the viewer is being transported to a magical realm where music and nature intertwine. The picture is bathed in a soft, golden light, casting a warm glow on the girl's face, illuminating her joy and connection to the music she creates.",
"", "",
[], [],
30, 30,
"Euler a", "DPM++ 3M SDE Karras",
1, 1,
1, 1,
7, 7,
512, 512,
512, 512,
true, false,
0.7, 0.7,
2, 2,
"Latent", "Latent",
10, 0,
0, 0,
0, 0,
"Use same checkpoint", "Use same checkpoint",
@@ -33,6 +33,9 @@
0, 0,
0, 0,
0, 0,
null,
null,
null,
false, false,
false, false,
"positive", "positive",
@@ -55,13 +58,22 @@
false, false,
false, false,
0, 0,
[ null,
], null,
false,
null,
null,
false,
null,
null,
false,
50,
[],
"", "",
"", "",
"" ""
], ],
"event_data": null, "event_data": null,
"fn_index": 95, "fn_index": 316,
"session_hash": "eqwumnt3rov" "session_hash": "ttr8efgt63g"
} }

View File

@@ -19,17 +19,18 @@ var logger = logger2.GetLogger()
type Bot struct { type Bot struct {
config *types.MidJourneyConfig config *types.MidJourneyConfig
bot *discordgo.Session bot *discordgo.Session
name string
service *Service service *Service
} }
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { func NewBot(name string, proxy string, config *types.MidJourneyConfig, service *Service) (*Bot, error) {
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken) discord, err := discordgo.New("Bot " + config.BotToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.ProxyURL != "" { if proxy != "" {
proxy, _ := url.Parse(config.ProxyURL) proxy, _ := url.Parse(proxy)
discord.Client = &http.Client{ discord.Client = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyURL(proxy), Proxy: http.ProxyURL(proxy),
@@ -41,8 +42,9 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
} }
return &Bot{ return &Bot{
config: &config.MjConfig, config: config,
bot: discord, bot: discord,
name: name,
service: service, service: service,
}, nil }, nil
} }
@@ -52,13 +54,13 @@ func (b *Bot) Run() error {
b.bot.AddHandler(b.messageCreate) b.bot.AddHandler(b.messageCreate)
b.bot.AddHandler(b.messageUpdate) b.bot.AddHandler(b.messageUpdate)
logger.Info("Starting MidJourney Bot...") logger.Infof("Starting MidJourney %s", b.name)
err := b.bot.Open() err := b.bot.Open()
if err != nil { if err != nil {
logger.Error("Error opening Discord connection:", err) logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err)
return err return err
} }
logger.Info("Starting MidJourney Bot successfully!") logger.Infof("Starting MidJourney %s successfully!", b.name)
return nil return nil
} }

View File

@@ -14,13 +14,14 @@ type Client struct {
config *types.MidJourneyConfig config *types.MidJourneyConfig
} }
func NewClient(config *types.AppConfig) *Client { func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
client := req.C().SetTimeout(10 * time.Second) client := req.C().SetTimeout(10 * time.Second)
// set proxy URL // set proxy URL
if config.ProxyURL != "" { if proxy != "" {
client.SetProxyURL(config.ProxyURL) client.SetProxyURL(proxy)
} }
return &Client{client: client, config: &config.MjConfig} logger.Info(proxy)
return &Client{client: client, config: config}
} }
func (c *Client) Imagine(prompt string) error { func (c *Client) Imagine(prompt string) error {

66
api/service/mj/pool.go Normal file
View File

@@ -0,0 +1,66 @@
package mj
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(&config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, queue, 4, 600, db, client, manager, appConfig)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, &config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: queue,
services: services,
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -5,60 +5,59 @@ import (
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"encoding/base64"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
"sync/atomic"
"time" "time"
) )
// MJ 绘画服务 // Service MJ 绘画服务
const RunningJobKey = "MidJourney_Running_Job"
type Service struct { type Service struct {
client *Client // MJ 客户端 name string // service name
client *Client // MJ client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
proxyURL string proxyURL string
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
} }
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { func NewService(name string, queue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
return &Service{ return &Service{
redis: redisCli, name: name,
db: db, db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), taskQueue: queue,
client: client, client: client,
uploadManager: manager, uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](), taskTimeout: timeout,
ChatClients: types.NewLMap[string, *types.WsClient](), maxHandleTaskNum: maxTaskNum,
proxyURL: config.ProxyURL, proxyURL: config.ProxyURL,
taskStartTimes: make(map[int]time.Time, 0),
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Info("Starting MidJourney job consumer.") logger.Infof("Starting MidJourney job consumer for %s", s.name)
ctx := context.Background()
for { for {
_, err := s.redis.Get(ctx, RunningJobKey).Result() s.checkTasks()
if err == nil { // 队列串行执行 if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
var task types.MjTask var task types.MjTask
err = s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
logger.Infof("Consuming Task: %+v", task)
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
err = s.client.Imagine(task.Prompt) err = s.client.Imagine(task.Prompt)
@@ -70,50 +69,43 @@ func (s *Service) Run() {
case types.TaskVariation: case types.TaskVariation:
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash) err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
} }
if err != nil { if err != nil {
logger.Error("绘画任务执行失败:", err) logger.Error("绘画任务执行失败:", err)
// 删除任务 // update the task progress
s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)}) s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// 推送任务到前端 atomic.AddInt32(&s.handledTaskNum, -1)
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, vo.MidJourneyJob{
Type: task.Type.String(),
UserId: task.UserId,
MessageId: task.MessageId,
Progress: -1,
Prompt: task.Prompt,
})
}
continue continue
} }
// 更新任务的执行状态 // lock the task until the execute timeout
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) s.taskStartTimes[task.Id] = time.Now()
// 锁定任务执行通道直到任务超时5分钟 atomic.AddInt32(&s.handledTaskNum, 1)
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
} }
} }
func (s *Service) PushTask(task types.MjTask) { // check if current service instance can handle more task
logger.Infof("add a new MidJourney Task: %+v", task) func (s *Service) canHandleTask() bool {
s.taskQueue.RPush(task) handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
} }
func (s *Service) Notify(data CBReq) { func (s *Service) Notify(data CBReq) {
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result() // extract the task ID
if err != nil { // 过期任务,丢弃 split := strings.Split(data.Prompt, " ")
logger.Warn("任务已过期:", err)
return
}
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return
}
var job model.MidJourneyJob var job model.MidJourneyJob
res := s.db.Where("message_id = ?", data.MessageId).First(&job) res := s.db.Where("message_id = ?", data.MessageId).First(&job)
if res.Error == nil && data.Status == Finished { if res.Error == nil && data.Status == Finished {
@@ -121,9 +113,7 @@ func (s *Service) Notify(data CBReq) {
return return
} }
if task.Src == types.TaskSrcImg { // 绘画任务 res = s.db.Where("task_id = ?", split[0]).First(&job)
var job model.MidJourneyJob
res := s.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil { if res.Error != nil {
logger.Warn("非法任务:", res.Error) logger.Warn("非法任务:", res.Error)
return return
@@ -133,125 +123,30 @@ func (s *Service) Notify(data CBReq) {
job.Progress = data.Progress job.Progress = data.Progress
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.Hash = data.Image.Hash job.Hash = data.Image.Hash
job.OrgURL = data.Image.URL
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imgURL
} else {
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = s.db.Updates(&job) res = s.db.Updates(&job)
if res.Error != nil { if res.Error != nil {
logger.Error("error with update job: ", res.Error) logger.Error("error with update job: ", res.Error)
return return
} }
var jobVo vo.MidJourneyJob // upload image
err := utils.CopyObject(job, &jobVo)
if err == nil {
if data.Progress < 100 {
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
// 推送任务到前端
client := s.Clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := s.ChatClients.Get(task.SessionId)
if data.Status == Finished { if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
utils.ReplyMessage(wsClient, content)
}
// download image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true) imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil { if err != nil {
logger.Error("error with download image: ", err) logger.Error("error with download img: ", err.Error())
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
return return
} }
tx := s.db.Begin()
data.Image.URL = imgURL
message := model.HistoryMessage{
UserId: uint(task.UserId),
ChatId: task.ChatId,
RoleId: uint(task.RoleId),
Type: types.MjMsg,
Icon: task.Icon,
Content: utils.JsonEncode(data),
Tokens: 0,
UseContext: false,
}
res = tx.Create(&message)
if res.Error != nil {
logger.Error("error with update database: ", err)
return
}
// save the job
job.UserId = task.UserId
job.Type = task.Type.String()
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Prompt = data.Prompt
job.ImgURL = imgURL job.ImgURL = imgURL
job.Progress = data.Progress s.db.Updates(&job)
job.Hash = data.Image.Hash
job.CreatedAt = time.Now()
res = tx.Create(&job)
if res.Error != nil {
logger.Error("error with update database: ", err)
tx.Rollback()
return
}
tx.Commit()
}
if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data)
return
} }
if data.Status == Finished { if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) // update user's img calls
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 本次绘画完毕,移除客户端 // release lock task
s.ChatClients.Delete(task.SessionId) atomic.AddInt32(&s.handledTaskNum, -1)
} else {
// 使用代理临时转发图片
if data.Image.URL != "" {
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
if err == nil {
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
}
}
// 更新用户剩余绘图次数
// TODO: 放大图片是否需要消耗绘图次数?
if data.Status == Finished {
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 解除任务锁定
s.redis.Del(context.Background(), RunningJobKey)
} }
} }

View File

@@ -32,6 +32,10 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
return nil, err return nil, err
} }
if config.SubDir == "" {
config.SubDir = "gpt"
}
return &AliYunOss{ return &AliYunOss{
config: config, config: config,
bucket: bucket, bucket: bucket,
@@ -54,14 +58,14 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
defer src.Close() defer src.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件 // 上传文件
err = s.bucket.PutObject(objectKey, src) err = s.bucket.PutObject(objectKey, src)
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Domain, objectKey), nil return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
} }
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
@@ -80,18 +84,19 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
fileExt := filepath.Ext(parse.Path) fileExt := filepath.Ext(parse.Path)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件字节数据 // 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Domain, objectKey), nil return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
} }
func (s AliYunOss) Delete(fileURL string) error { func (s AliYunOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL) objectName := filepath.Base(fileURL)
return s.bucket.DeleteObject(objectName) key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.bucket.DeleteObject(key)
} }
var _ Uploader = AliYunOss{} var _ Uploader = AliYunOss{}

View File

@@ -29,6 +29,9 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
if err != nil { if err != nil {
return MiniOss{}, err return MiniOss{}, err
} }
if config.SubDir == "" {
config.SubDir = "gpt"
}
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
} }
@@ -48,7 +51,7 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
fileExt := filepath.Ext(parse.Path) fileExt := filepath.Ext(parse.Path)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject( info, err := s.client.PutObject(
context.Background(), context.Background(),
s.config.Bucket, s.config.Bucket,
@@ -75,7 +78,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
defer fileReader.Close() defer fileReader.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt) filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Content-Type"), ContentType: file.Header.Get("Content-Type"),
}) })
@@ -88,7 +91,8 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
func (s MiniOss) Delete(fileURL string) error { func (s MiniOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL) objectName := filepath.Base(fileURL)
return s.client.RemoveObject(context.Background(), s.config.Bucket, objectName, minio.RemoveObjectOptions{}) key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.client.RemoveObject(context.Background(), s.config.Bucket, key, minio.RemoveObjectOptions{})
} }
var _ Uploader = MiniOss{} var _ Uploader = MiniOss{}

View File

@@ -21,7 +21,6 @@ type QinNiuOss struct {
uploader *storage.FormUploader uploader *storage.FormUploader
manager *storage.BucketManager manager *storage.BucketManager
proxyURL string proxyURL string
dir string
} }
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
@@ -38,6 +37,9 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{ putPolicy := storage.PutPolicy{
Scope: config.Bucket, Scope: config.Bucket,
} }
if config.SubDir == "" {
config.SubDir = "gpt"
}
return QinNiuOss{ return QinNiuOss{
config: config, config: config,
mac: mac, mac: mac,
@@ -45,7 +47,6 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
uploader: formUploader, uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig), manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL, proxyURL: appConfig.ProxyURL,
dir: "chatgpt-plus",
} }
} }
@@ -63,7 +64,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
defer src.Close() defer src.Close()
fileExt := filepath.Ext(file.Filename) fileExt := filepath.Ext(file.Filename)
key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt) key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
// 上传文件 // 上传文件
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
@@ -91,7 +92,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
return "", fmt.Errorf("error with parse image URL: %v", err) return "", fmt.Errorf("error with parse image URL: %v", err)
} }
fileExt := filepath.Ext(parse.Path) fileExt := filepath.Ext(parse.Path)
key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt) key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
ret := storage.PutRet{} ret := storage.PutRet{}
extra := storage.PutExtra{} extra := storage.PutExtra{}
// 上传文件字节数据 // 上传文件字节数据
@@ -104,7 +105,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
func (s QinNiuOss) Delete(fileURL string) error { func (s QinNiuOss) Delete(fileURL string) error {
objectName := filepath.Base(fileURL) objectName := filepath.Base(fileURL)
key := fmt.Sprintf("%s/%s", s.dir, objectName) key := fmt.Sprintf("%s/%s", s.config.SubDir, objectName)
return s.manager.Delete(s.config.Bucket, key) return s.manager.Delete(s.config.Bucket, key)
} }

View File

@@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
var data string var data string
keys := make([]string, 0, 0) keys := make([]string, 0, 0)
params["appid"] = s.appId params["appid"] = s.appId
for key, _ := range params { for key := range params {
keys = append(keys, key) keys = append(keys, key)
} }
sort.Strings(keys) sort.Strings(keys)

52
api/service/sd/pool.go Normal file
View File

@@ -0,0 +1,52 @@
package sd
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"fmt"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
queue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%d", k)
service := NewService(name, 4, 600, &config, queue, db, manager)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: queue,
services: services,
}
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.SdTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}

View File

@@ -5,84 +5,96 @@ import (
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store" "chatplus/store"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
"io" "io"
"os" "os"
"strconv" "strconv"
"sync/atomic"
"time" "time"
) )
// SD 绘画服务 // SD 绘画服务
const RunningJobKey = "StableDiffusion_Running_Job"
type Service struct { type Service struct {
httpClient *req.Client httpClient *req.Client
config *types.StableDiffusionConfig config *types.StableDiffusionConfig
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB db *gorm.DB
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池 name string // service name
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
} }
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service { func NewService(name string, maxTaskNum int32, timeout int64, config *types.StableDiffusionConfig, queue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service {
return &Service{ return &Service{
config: &config.SdConfig, name: name,
config: config,
httpClient: req.C(), httpClient: req.C(),
redis: redisCli, taskQueue: queue,
db: db, db: db,
uploadManager: manager, uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](), taskTimeout: timeout,
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli), maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time),
} }
} }
func (s *Service) Run() { func (s *Service) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for { for {
_, err := s.redis.Get(ctx, RunningJobKey).Result() s.checkTasks()
if err == nil { // 队列串行执行 if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
var task types.SdTask var task types.SdTask
err = s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
logger.Errorf("taking task with error: %v", err) logger.Errorf("taking task with error: %v", err)
continue continue
} }
logger.Infof("Consuming Task: %+v", task) logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
err = s.Txt2Img(task) err = s.Txt2Img(task)
if err != nil { if err != nil {
logger.Error("绘画任务执行失败:", err) logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 { // update the task progress
s.taskQueue.RPush(task) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
} // release task num
task.RetryCount += 1 atomic.AddInt32(&s.handledTaskNum, -1)
time.Sleep(time.Second * 3)
continue continue
} }
// 更新任务的执行状态 // lock the task until the execute timeout
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) s.taskStartTimes[task.Id] = time.Now()
// 锁定任务执行通道直到任务超时5分钟 atomic.AddInt32(&s.handledTaskNum, 1)
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
} }
} }
// PushTask 推送任务到队列 // check if current service instance can handle more task
func (s *Service) PushTask(task types.SdTask) { func (s *Service) canHandleTask() bool {
logger.Infof("add a new Stable Diffusion Task: %+v", task) handledNum := atomic.LoadInt32(&s.handledTaskNum)
s.taskQueue.RPush(task) return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
} }
// Txt2Img 文生图 API // Txt2Img 文生图 API
@@ -237,9 +249,8 @@ func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
} }
func (s *Service) callback(data CBReq) { func (s *Service) callback(data CBReq) {
// 释放任务锁 // release task num
s.redis.Del(context.Background(), RunningJobKey) atomic.AddInt32(&s.handledTaskNum, -1)
client := s.Clients.Get(data.SessionId)
if data.Success { // 任务成功 if data.Success { // 任务成功
var job model.SdJob var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job) res := s.db.Where("id = ?", data.JobId).First(&job)
@@ -259,14 +270,16 @@ func (s *Service) callback(data CBReq) {
params.Seed = data.Seed params.Seed = data.Seed
if data.ImageName != "" { // 下载图片 if data.ImageName != "" { // 下载图片
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false) if data.Progress == 100 {
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false)
if err != nil { if err != nil {
logger.Error("error with download img: ", err.Error()) logger.Error("error with download img: ", err.Error())
return return
} }
job.ImgURL = imageURL job.ImgURL = imageURL
} }
}
job.Params = utils.JsonEncode(params) job.Params = utils.JsonEncode(params)
res = s.db.Updates(&job) res = s.db.Updates(&job)
@@ -275,38 +288,16 @@ func (s *Service) callback(data CBReq) {
return return
} }
var jobVo vo.SdJob logger.Debugf("绘图进度:%d", data.Progress)
err = utils.CopyObject(job, &jobVo)
if err != nil {
logger.Error("error with copy object: ", err)
return
}
if data.Progress < 100 && data.ImageData != "" {
jobVo.ImgURL = data.ImageData
}
logger.Infof("绘图进度:%d", data.Progress)
// 扣减绘图次数 // 扣减绘图次数
if data.Progress == 100 { if data.Progress == 100 {
s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", jobVo.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) s.db.Model(&model.User{}).Where("id = ? AND img_calls > 0", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
// 推送任务到前端
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
} }
} else { // 任务失败 } else { // 任务失败
logger.Error("任务执行失败:", data.Message) logger.Error("任务执行失败:", data.Message)
// 删除任务 // update the task progress
s.db.Delete(&model.SdJob{Id: uint(data.JobId)}) s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumn("progress", -1)
// 推送消息到前端
if client != nil {
utils.ReplyChunkMessage(client, vo.SdJob{
Id: uint(data.JobId),
Progress: -1,
TaskId: data.TaskId,
})
}
} }
} }

View File

@@ -23,7 +23,7 @@ func NewSnowflake() *Snowflake {
} }
// Next 生成一个新的唯一ID // Next 生成一个新的唯一ID
func (s *Snowflake) Next() (string, error) { func (s *Snowflake) Next(raw bool) (string, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -43,6 +43,9 @@ func (s *Snowflake) Next() (string, error) {
s.lastTimestamp = timestamp s.lastTimestamp = timestamp
id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence) id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence)
if raw {
return fmt.Sprintf("%d", id), nil
}
now := time.Now() now := time.Now()
return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil
} }

View File

@@ -6,13 +6,14 @@ type MidJourneyJob struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
Type string Type string
UserId int UserId int
TaskId string
MessageId string MessageId string
ReferenceId string ReferenceId string
ImgURL string ImgURL string
OrgURL string // 原图地址
Hash string // message hash Hash string // message hash
Progress int Progress int
Prompt string Prompt string
Started bool
CreatedAt time.Time CreatedAt time.Time
} }

View File

@@ -11,7 +11,6 @@ type SdJob struct {
Progress int Progress int
Prompt string Prompt string
Params string Params string
Started bool
CreatedAt time.Time CreatedAt time.Time
} }

View File

@@ -6,12 +6,13 @@ type MidJourneyJob struct {
Id uint `json:"id"` Id uint `json:"id"`
Type string `json:"type"` Type string `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
TaskId string `json:"task_id"`
MessageId string `json:"message_id"` MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"` ReferenceId string `json:"reference_id"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
OrgURL string `json:"org_url"`
Hash string `json:"hash"` Hash string `json:"hash"`
Progress int `json:"progress"` Progress int `json:"progress"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
} }

View File

@@ -15,5 +15,4 @@ type SdJob struct {
Progress int `json:"progress"` Progress int `json:"progress"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
} }

View File

@@ -244,12 +244,29 @@
</el-icon> </el-icon>
</el-tooltip> </el-tooltip>
</div> </div>
<el-button type="success" @click="translatePrompt"> <div>
<el-button type="primary" @click="translatePrompt">
<el-icon style="margin-right: 6px;font-size: 18px;"> <el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/> <Refresh/>
</el-icon> </el-icon>
翻译 翻译
</el-button> </el-button>
<el-tooltip
class="box-item"
effect="light"
raw-content
content="使用 AI 翻译并重写提示词,<br/>增加更多细节,风格等描述"
placement="top-end"
>
<el-button type="success" @click="rewritePrompt">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译并重写
</el-button>
</el-tooltip>
</div>
</div> </div>
</div> </div>
@@ -432,8 +449,7 @@ import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId, getUserToken} from "@/store/session"; import {getSessionId} from "@/store/session";
import {removeArrayItem} from "@/utils/libs";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const mjBoxHeight = ref(window.innerHeight - 150)
@@ -504,71 +520,14 @@ const socket = ref(null)
const imgCalls = ref(0) const imgCalls = ref(0)
const loading = ref(false) const loading = ref(false)
const connect = () => { const rewritePrompt = () => {
let host = process.env.VUE_APP_WS_HOST loading.value = true
if (host === '') { httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => {
if (location.protocol === 'https:') { params.value.prompt = res.data
host = 'wss://' + location.host; loading.value = false
} else { }).catch(e => {
host = 'ws://' + location.host; ElMessage.error("翻译失败:" + e.message)
}
}
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
let isNew = true
if (data.progress === 100) {
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
isNew = false
break
}
}
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1)
break
}
}
if (isNew) {
finishedJobs.value.unshift(data)
}
} else if (data.progress === -1) { // 任务执行失败
ElNotification({
title: '任务执行失败',
message: "提示词:" + data['prompt'],
type: 'error',
}) })
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
} else {
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
isNew = false
runningJobs.value[i] = data
break
}
}
if (isNew) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
ElMessage.error("Websocket 已经断开,正在重新连接服务器")
connect()
});
} }
const translatePrompt = () => { const translatePrompt = () => {
@@ -576,7 +535,7 @@ const translatePrompt = () => {
httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => { httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
params.value.prompt = res.data params.value.prompt = res.data
loading.value = false loading.value = false
}).then(e => { }).catch(e => {
ElMessage.error("翻译失败:" + e.message) ElMessage.error("翻译失败:" + e.message)
}) })
} }
@@ -584,22 +543,10 @@ const translatePrompt = () => {
onMounted(() => { onMounted(() => {
checkSession().then(user => { checkSession().then(user => {
imgCalls.value = user['img_calls'] imgCalls.value = user['img_calls']
// 获取运行中的任务
httpGet(`/api/mj/jobs?status=0&user_id=${user['id']}`).then(res => {
runningJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// 获取已完成的任务 fetchRunningJobs(user.id)
httpGet(`/api/mj/jobs?status=1&user_id=${user['id']}`).then(res => { fetchFinishJobs(user.id)
finishedJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// 连接 socket
connect();
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
}); });
@@ -614,6 +561,41 @@ onMounted(() => {
}) })
}) })
// 获取运行中的任务
const fetchRunningJobs = (userId) => {
httpGet(`/api/mj/jobs?status=0&user_id=${userId}`).then(res => {
const jobs = res.data
const _jobs = []
for (let i = 0; i < jobs.length; i++) {
if (jobs[i].progress === -1) {
ElNotification({
title: '任务执行失败',
message: "任务ID" + jobs[i]['task_id'],
type: 'error',
})
continue
}
_jobs.push(jobs[i])
}
runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
const fetchFinishJobs = (userId) => {
// 获取已完成的任务
httpGet(`/api/mj/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
// 切换图片比例 // 切换图片比例
const changeRate = (item) => { const changeRate = (item) => {
params.value.rate = item.value params.value.rate = item.value
@@ -676,7 +658,6 @@ const variation = (index, item) => {
const send = (url, index, item) => { const send = (url, index, item) => {
httpPost(url, { httpPost(url, {
index: index, index: index,
src: "img",
message_id: item.message_id, message_id: item.message_id,
message_hash: item.hash, message_hash: item.hash,
session_id: getSessionId(), session_id: getSessionId(),

View File

@@ -241,7 +241,7 @@
</div> </div>
</div> </div>
<div class="param-line"> <div class="param-line" v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.8)">
<el-input <el-input
v-model="params.prompt" v-model="params.prompt"
:autosize="{ minRows: 4, maxRows: 6 }" :autosize="{ minRows: 4, maxRows: 6 }"
@@ -251,6 +251,30 @@
/> />
</div> </div>
<div style="padding: 10px">
<el-button type="primary" @click="translatePrompt" size="small">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译
</el-button>
<el-tooltip
class="box-item"
effect="dark"
raw-content
content="使用 AI 翻译并重写提示词,<br/>增加更多细节,风格等描述"
placement="top-end"
>
<el-button type="success" @click="rewritePrompt" size="small">
<el-icon style="margin-right: 6px;font-size: 18px;">
<Refresh/>
</el-icon>
翻译并重写
</el-button>
</el-tooltip>
</div>
<div class="param-line pt"> <div class="param-line pt">
<span>反向提示词</span> <span>反向提示词</span>
<el-tooltip <el-tooltip
@@ -272,12 +296,8 @@
/> />
</div> </div>
<div class="param-line pt"> <div class="param-line" style="padding: 10px">
<el-form-item label="剩余次数"> <el-tag type="success">绘图可用额度{{ imgCalls }}</el-tag>
<template #default>
<el-tag type="info">{{ imgCalls }}</el-tag>
</template>
</el-form-item>
</div> </div>
</el-form> </el-form>
</div> </div>
@@ -478,21 +498,21 @@
<script setup> <script setup>
import {onMounted, ref} from "vue" import {onMounted, ref} from "vue"
import {DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue"; import {DocumentCopy, InfoFilled, Orange, Picture, Refresh} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElNotification} from "element-plus"; import {ElMessage, ElNotification} from "element-plus";
import ItemList from "@/components/ItemList.vue"; import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
import {getSessionId, getUserToken} from "@/store/session"; import {getSessionId} from "@/store/session";
import {removeArrayItem} from "@/utils/libs";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const mjBoxHeight = ref(window.innerHeight - 150)
const fullImgHeight = ref(window.innerHeight - 60) const fullImgHeight = ref(window.innerHeight - 60)
const showTaskDialog = ref(false) const showTaskDialog = ref(false)
const item = ref({}) const item = ref({})
const loading = ref(false)
window.onresize = () => { window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40 listBoxHeight.value = window.innerHeight - 40
@@ -515,116 +535,84 @@ const params = ref({
hd_scale_alg: scaleAlg[0], hd_scale_alg: scaleAlg[0],
hd_steps: 10, hd_steps: 10,
prompt: "", prompt: "",
negative_prompt: "nsfw, paintings, cartoon, anime, sketches, low quality,easynegative,ng_deepnegative _v1 75t,(worst quality:2),(low quality:2),(normalquality:2),lowres,bad anatomy,bad hands,normal quality,((monochrome)),((grayscale)),((watermark))", negative_prompt: "nsfw, paintings,low quality,easynegative,ng_deepnegative ,lowres,bad anatomy,bad hands,bad feet",
}) })
const runningJobs = ref([]) const runningJobs = ref([])
const finishedJobs = ref([]) const finishedJobs = ref([])
const previewImgList = ref([])
const router = useRouter() const router = useRouter()
// 检查是否有画同款的参数 // 检查是否有画同款的参数
const _params = router.currentRoute.value.params["copyParams"] const _params = router.currentRoute.value.params["copyParams"]
if (_params) { if (_params) {
params.value = JSON.parse(_params) params.value = JSON.parse(_params)
} }
const socket = ref(null)
const imgCalls = ref(0) const imgCalls = ref(0)
const connect = () => { const rewritePrompt = () => {
let host = process.env.VUE_APP_WS_HOST loading.value = true
if (host === '') { httpPost("/api/prompt/rewrite", {"prompt": params.value.prompt}).then(res => {
if (location.protocol === 'https:') { params.value.prompt = res.data
host = 'wss://' + location.host; loading.value = false
} else { }).catch(e => {
host = 'ws://' + location.host; ElMessage.error("翻译失败:" + e.message)
}
}
const _socket = new WebSocket(host + `/api/sd/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
let append = true
if (data.progress === 100) { // 任务已完成
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
append = false
break
}
}
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1)
break
}
}
if (append) {
finishedJobs.value.unshift(data)
}
previewImgList.value.unshift(data["img_url"])
} else if (data.progress === -1) { // 任务执行失败
ElNotification({
title: '任务执行失败',
message: "任务ID" + data['task_id'],
type: 'error',
}) })
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id) }
} else { // 启动新的任务 const translatePrompt = () => {
for (let i = 0; i < runningJobs.value.length; i++) { loading.value = true
if (runningJobs.value[i].id === data.id) { httpPost("/api/prompt/translate", {"prompt": params.value.prompt}).then(res => {
append = false params.value.prompt = res.data
runningJobs.value[i] = data loading.value = false
break }).catch(e => {
} ElMessage.error("翻译失败:" + e.message)
} })
if (append) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
connect()
});
} }
onMounted(() => { onMounted(() => {
checkSession().then(user => { checkSession().then(user => {
imgCalls.value = user['img_calls'] imgCalls.value = user['img_calls']
// 获取运行中的任务
httpGet(`/api/sd/jobs?status=0&user_id=${user['id']}`).then(res => {
runningJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// 获取运行中的任务 fetchRunningJobs(user.id)
httpGet(`/api/sd/jobs?status=1&user_id=${user['id']}`).then(res => { fetchFinishJobs(user.id)
finishedJobs.value = res.data
previewImgList.value = []
for (let index in finishedJobs.value) {
previewImgList.value.push(finishedJobs.value[index]["img_url"])
}
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// 连接 socket
connect();
}).catch(() => { }).catch(() => {
router.push('/login') router.push('/login')
}); });
const fetchRunningJobs = (userId) => {
// 获取运行中的任务
httpGet(`/api/sd/jobs?status=0&user_id=${userId}`).then(res => {
const jobs = res.data
const _jobs = []
for (let i = 0; i < jobs.length; i++) {
if (jobs[i].progress === -1) {
ElNotification({
title: '任务执行失败',
message: "任务ID" + jobs[i]['task_id'],
type: 'error',
})
continue
}
_jobs.push(jobs[i])
}
runningJobs.value = _jobs
setTimeout(() => fetchRunningJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
// 获取已完成的任务
const fetchFinishJobs = (userId) => {
httpGet(`/api/sd/jobs?status=1&user_id=${userId}`).then(res => {
finishedJobs.value = res.data
setTimeout(() => fetchFinishJobs(userId), 5000)
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
}
const clipboard = new Clipboard('.copy-prompt'); const clipboard = new Clipboard('.copy-prompt');
clipboard.on('success', () => { clipboard.on('success', () => {
ElMessage.success("复制成功!"); ElMessage.success("复制成功!");

View File

@@ -101,7 +101,7 @@ const login = function () {
httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => { httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
setUserToken(res.data) setUserToken(res.data)
if (prevRoute.path === '') { if (prevRoute.path === '' || prevRoute.path === '/register') {
if (isMobile()) { if (isMobile()) {
router.push('/mobile') router.push('/mobile')
} else { } else {

View File

@@ -127,14 +127,6 @@ import {isMobile} from "@/utils/libs";
import {setUserToken} from "@/store/session"; import {setUserToken} from "@/store/session";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
checkSession().then(() => {
if (isMobile()) {
router.push('/mobile')
} else {
router.push('/chat')
}
}).catch(() => {
})
const router = useRouter(); const router = useRouter();
const title = ref('ChatGPT-PLUS 用户注册'); const title = ref('ChatGPT-PLUS 用户注册');
const formData = ref({ const formData = ref({