mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
refactor mj service, add mj service pool support
This commit is contained in:
parent
c012f0c4c5
commit
cf758d773e
@ -33,7 +33,7 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
BasePath: "./static/upload",
|
||||
},
|
||||
},
|
||||
MjConfig: types.MidJourneyConfig{Enabled: false},
|
||||
MjConfigs: types.MidJourneyConfig{Enabled: false},
|
||||
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
|
||||
WeChatBot: false,
|
||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||
|
@ -18,7 +18,7 @@ type AppConfig struct {
|
||||
AesEncryptKey string
|
||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||
OSS OSSConfig // OSS config
|
||||
MjConfig MidJourneyConfig // mj 绘画配置
|
||||
MjConfigs []MidJourneyConfig // mj 绘画配置池子
|
||||
WeChatBot bool // 是否启用微信机器人
|
||||
SdConfig StableDiffusionConfig // sd 绘画配置
|
||||
|
||||
@ -116,7 +116,7 @@ type ChatConfig struct {
|
||||
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
||||
ContextDeep int `json:"context_deep"` // 上下文深度
|
||||
DallApiURL string `json:"dall_api_url"` // dall-e3 绘图 API 地址
|
||||
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
|
||||
DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
|
||||
}
|
||||
|
||||
type Platform string
|
||||
|
@ -11,28 +11,15 @@ const (
|
||||
TaskImage = TaskType("image")
|
||||
TaskUpscale = TaskType("upscale")
|
||||
TaskVariation = TaskType("variation")
|
||||
TaskTxt2Img = TaskType("text2img")
|
||||
)
|
||||
|
||||
// TaskSrc 任务来源
|
||||
type TaskSrc string
|
||||
|
||||
const (
|
||||
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
|
||||
TaskSrcImg = TaskSrc("img") // 专业绘画页面
|
||||
)
|
||||
|
||||
// MjTask MidJourney 任务
|
||||
type MjTask struct {
|
||||
Id int `json:"id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Src TaskSrc `json:"src"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
ChatId string `json:"chat_id,omitempty"`
|
||||
RoleId int `json:"role_id,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
MessageId string `json:"message_id,omitempty"`
|
||||
MessageHash string `json:"message_hash,omitempty"`
|
||||
@ -42,7 +29,6 @@ type MjTask struct {
|
||||
type SdTask struct {
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
SessionId string `json:"session_id"`
|
||||
Src TaskSrc `json:"src"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
|
@ -12,9 +12,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -40,20 +38,6 @@ func NewMidJourneyHandler(
|
||||
return &h
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
client := types.NewWsClient(ws)
|
||||
h.mjService.Clients.Put(sessionId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||
user, err := utils.GetLoginUser(c, h.db)
|
||||
if err != nil {
|
||||
@ -72,7 +56,7 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
if !h.App.Config.MjConfig.Enabled {
|
||||
if !h.App.Config.MjConfigs[0].Enabled {
|
||||
resp.ERROR(c, "MidJourney service is disabled")
|
||||
return
|
||||
}
|
||||
|
19
api/main.go
19
api/main.go
@ -165,23 +165,6 @@ func main() {
|
||||
|
||||
// MidJourney 机器人
|
||||
fx.Provide(mj.NewBot),
|
||||
fx.Provide(mj.NewClient),
|
||||
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
|
||||
if config.MjConfig.Enabled {
|
||||
err := bot.Run()
|
||||
if err != nil {
|
||||
log.Fatal("MidJourney 服务启动失败:", err)
|
||||
}
|
||||
}
|
||||
}),
|
||||
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
|
||||
if config.MjConfig.Enabled {
|
||||
go func() {
|
||||
mjService.Run()
|
||||
}()
|
||||
}
|
||||
}),
|
||||
|
||||
// Stable Diffusion 机器人
|
||||
fx.Provide(sd.NewService),
|
||||
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
|
||||
@ -256,13 +239,11 @@ func main() {
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
group := s.Engine.Group("/api/sd")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
|
@ -23,7 +23,7 @@ type Bot struct {
|
||||
}
|
||||
|
||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
|
||||
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -41,7 +41,7 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||
}
|
||||
|
||||
return &Bot{
|
||||
config: &config.MjConfig,
|
||||
config: &config.MjConfigs,
|
||||
bot: discord,
|
||||
service: service,
|
||||
}, nil
|
||||
|
@ -2,6 +2,7 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
@ -14,13 +15,13 @@ type Client struct {
|
||||
config *types.MidJourneyConfig
|
||||
}
|
||||
|
||||
func NewClient(config *types.AppConfig) *Client {
|
||||
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
|
||||
client := req.C().SetTimeout(10 * time.Second)
|
||||
// set proxy URL
|
||||
if config.ProxyURL != "" {
|
||||
client.SetProxyURL(config.ProxyURL)
|
||||
if utils.IsEmptyValue(proxy) {
|
||||
client.SetProxyURL(proxy)
|
||||
}
|
||||
return &Client{client: client, config: &config.MjConfig}
|
||||
return &Client{client: client, config: config}
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(prompt string) error {
|
||||
|
38
api/service/mj/pool.go
Normal file
38
api/service/mj/pool.go
Normal file
@ -0,0 +1,38 @@
|
||||
package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ServicePool Mj service pool
|
||||
type ServicePool struct {
|
||||
services []Service
|
||||
taskQueue *store.RedisQueue
|
||||
}
|
||||
|
||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||
// create mj client and service
|
||||
for _, config := range appConfig.MjConfigs {
|
||||
if config.Enabled == false {
|
||||
continue
|
||||
}
|
||||
// create mj client
|
||||
client := NewClient(&config, appConfig.ProxyURL)
|
||||
|
||||
// create mj service
|
||||
service := NewService()
|
||||
}
|
||||
|
||||
return &ServicePool{
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ServicePool) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
p.taskQueue.RPush(task)
|
||||
}
|
@ -2,63 +2,63 @@ package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MJ 绘画服务
|
||||
|
||||
const RunningJobKey = "MidJourney_Running_Job"
|
||||
|
||||
// Service MJ 绘画服务
|
||||
type Service struct {
|
||||
client *Client // MJ 客户端
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
|
||||
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
|
||||
proxyURL string
|
||||
name string // service name
|
||||
client *Client // MJ client
|
||||
taskQueue *store.RedisQueue
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
proxyURL string
|
||||
maxHandleTaskNum int32 // max task number current service can handle
|
||||
handledTaskNum int32 // already handled task number
|
||||
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||
taskTimeout int64
|
||||
snowflake *service.Snowflake
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||
return &Service{
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
client: client,
|
||||
uploadManager: manager,
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
proxyURL: config.ProxyURL,
|
||||
name: name,
|
||||
db: db,
|
||||
taskQueue: queue,
|
||||
client: client,
|
||||
uploadManager: manager,
|
||||
taskTimeout: timeout,
|
||||
proxyURL: config.ProxyURL,
|
||||
taskStartTimes: make(map[int]time.Time, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting MidJourney job consumer.")
|
||||
ctx := context.Background()
|
||||
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
s.checkTasks()
|
||||
if !s.canHandleTask() {
|
||||
// current service is full, can not handle more task
|
||||
// waiting for running task finish
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
var task types.MjTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
|
||||
logger.Infof("handle a new MidJourney task: %+v", task)
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
err = s.client.Imagine(task.Prompt)
|
||||
@ -70,50 +70,40 @@ func (s *Service) Run() {
|
||||
case types.TaskVariation:
|
||||
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
// 删除任务
|
||||
s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)})
|
||||
// 推送任务到前端
|
||||
client := s.Clients.Get(task.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, vo.MidJourneyJob{
|
||||
Type: task.Type.String(),
|
||||
UserId: task.UserId,
|
||||
MessageId: task.MessageId,
|
||||
Progress: -1,
|
||||
Prompt: task.Prompt,
|
||||
})
|
||||
}
|
||||
// update the task progress
|
||||
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
// lock the task until the execute timeout
|
||||
s.taskStartTimes[task.Id] = time.Now()
|
||||
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
// check if current service instance can handle more task
|
||||
func (s *Service) canHandleTask() bool {
|
||||
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
||||
return handledNum < s.maxHandleTaskNum
|
||||
}
|
||||
|
||||
func (s *Service) checkTasks() {
|
||||
for k, t := range s.taskStartTimes {
|
||||
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||
delete(s.taskStartTimes, k)
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Notify(data CBReq) {
|
||||
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return
|
||||
}
|
||||
|
||||
var task types.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// extract the task ID
|
||||
split := strings.Split(data.Prompt, " ")
|
||||
var job model.MidJourneyJob
|
||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||
if res.Error == nil && data.Status == Finished {
|
||||
@ -121,137 +111,37 @@ func (s *Service) Notify(data CBReq) {
|
||||
return
|
||||
}
|
||||
|
||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
||||
var job model.MidJourneyJob
|
||||
res := s.db.Where("id = ?", task.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
res = s.db.Where("task_id = ?", split[0]).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
job.OrgURL = data.Image.URL // save origin image
|
||||
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
} else {
|
||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||
job.ImgURL = data.Image.URL
|
||||
}
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return
|
||||
}
|
||||
// upload image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
if data.Progress < 100 {
|
||||
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
|
||||
if err == nil {
|
||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
// 推送任务到前端
|
||||
client := s.Clients.Get(task.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
|
||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
||||
wsClient := s.ChatClients.Get(task.SessionId)
|
||||
if data.Status == Finished {
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
// download image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
data.Image.URL = imgURL
|
||||
message := model.HistoryMessage{
|
||||
UserId: uint(task.UserId),
|
||||
ChatId: task.ChatId,
|
||||
RoleId: uint(task.RoleId),
|
||||
Type: types.MjMsg,
|
||||
Icon: task.Icon,
|
||||
Content: utils.JsonEncode(data),
|
||||
Tokens: 0,
|
||||
UseContext: false,
|
||||
}
|
||||
res = tx.Create(&message)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
// save the job
|
||||
job.UserId = task.UserId
|
||||
job.Type = task.Type.String()
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Prompt = data.Prompt
|
||||
job.ImgURL = imgURL
|
||||
job.Progress = data.Progress
|
||||
job.Hash = data.Image.Hash
|
||||
job.CreatedAt = time.Now()
|
||||
res = tx.Create(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database: ", err)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
if wsClient == nil { // 客户端断线,则丢弃
|
||||
logger.Errorf("Client is offline: %+v", data)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Status == Finished {
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||
// 本次绘画完毕,移除客户端
|
||||
s.ChatClients.Delete(task.SessionId)
|
||||
} else {
|
||||
// 使用代理临时转发图片
|
||||
if data.Image.URL != "" {
|
||||
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
|
||||
if err == nil {
|
||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
}
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新用户剩余绘图次数
|
||||
// TODO: 放大图片是否需要消耗绘图次数?
|
||||
if data.Status == Finished {
|
||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
// 解除任务锁定
|
||||
s.redis.Del(context.Background(), RunningJobKey)
|
||||
// update user's img calls
|
||||
s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
// release lock task
|
||||
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
|
||||
var data string
|
||||
keys := make([]string, 0, 0)
|
||||
params["appid"] = s.appId
|
||||
for key, _ := range params {
|
||||
for key := range params {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
@ -6,13 +6,14 @@ type MidJourneyJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Type string
|
||||
UserId int
|
||||
TaskId string
|
||||
MessageId string
|
||||
ReferenceId string
|
||||
ImgURL string
|
||||
OrgURL string // 原图地址
|
||||
Hash string // message hash
|
||||
Progress int
|
||||
Prompt string
|
||||
Started bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
|
@ -9,9 +9,9 @@ type MidJourneyJob struct {
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
ImgURL string `json:"img_url"`
|
||||
OrgURL string `json:"org_url"`
|
||||
Hash string `json:"hash"`
|
||||
Progress int `json:"progress"`
|
||||
Prompt string `json:"prompt"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Started bool `json:"started"`
|
||||
}
|
||||
|
@ -504,72 +504,72 @@ const socket = ref(null)
|
||||
const imgCalls = ref(0)
|
||||
const loading = ref(false)
|
||||
|
||||
const connect = () => {
|
||||
let host = process.env.VUE_APP_WS_HOST
|
||||
if (host === '') {
|
||||
if (location.protocol === 'https:') {
|
||||
host = 'wss://' + location.host;
|
||||
} else {
|
||||
host = 'ws://' + location.host;
|
||||
}
|
||||
}
|
||||
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
||||
_socket.addEventListener('open', () => {
|
||||
socket.value = _socket;
|
||||
});
|
||||
|
||||
_socket.addEventListener('message', event => {
|
||||
if (event.data instanceof Blob) {
|
||||
const reader = new FileReader();
|
||||
reader.readAsText(event.data, "UTF-8");
|
||||
reader.onload = () => {
|
||||
const data = JSON.parse(String(reader.result));
|
||||
let isNew = true
|
||||
if (data.progress === 100) {
|
||||
for (let i = 0; i < finishedJobs.value.length; i++) {
|
||||
if (finishedJobs.value[i].id === data.id) {
|
||||
isNew = false
|
||||
break
|
||||
}
|
||||
}
|
||||
for (let i = 0; i < runningJobs.value.length; i++) {
|
||||
if (runningJobs.value[i].id === data.id) {
|
||||
runningJobs.value.splice(i, 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
if (isNew) {
|
||||
finishedJobs.value.unshift(data)
|
||||
}
|
||||
} else if (data.progress === -1) { // 任务执行失败
|
||||
ElNotification({
|
||||
title: '任务执行失败',
|
||||
message: "提示词:" + data['prompt'],
|
||||
type: 'error',
|
||||
})
|
||||
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
|
||||
|
||||
} else {
|
||||
for (let i = 0; i < runningJobs.value.length; i++) {
|
||||
if (runningJobs.value[i].id === data.id) {
|
||||
isNew = false
|
||||
runningJobs.value[i] = data
|
||||
break
|
||||
}
|
||||
}
|
||||
if (isNew) {
|
||||
runningJobs.value.push(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
_socket.addEventListener('close', () => {
|
||||
ElMessage.error("Websocket 已经断开,正在重新连接服务器")
|
||||
connect()
|
||||
});
|
||||
}
|
||||
// const connect = () => {
|
||||
// let host = process.env.VUE_APP_WS_HOST
|
||||
// if (host === '') {
|
||||
// if (location.protocol === 'https:') {
|
||||
// host = 'wss://' + location.host;
|
||||
// } else {
|
||||
// host = 'ws://' + location.host;
|
||||
// }
|
||||
// }
|
||||
// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
||||
// _socket.addEventListener('open', () => {
|
||||
// socket.value = _socket;
|
||||
// });
|
||||
//
|
||||
// _socket.addEventListener('message', event => {
|
||||
// if (event.data instanceof Blob) {
|
||||
// const reader = new FileReader();
|
||||
// reader.readAsText(event.data, "UTF-8");
|
||||
// reader.onload = () => {
|
||||
// const data = JSON.parse(String(reader.result));
|
||||
// let isNew = true
|
||||
// if (data.progress === 100) {
|
||||
// for (let i = 0; i < finishedJobs.value.length; i++) {
|
||||
// if (finishedJobs.value[i].id === data.id) {
|
||||
// isNew = false
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// for (let i = 0; i < runningJobs.value.length; i++) {
|
||||
// if (runningJobs.value[i].id === data.id) {
|
||||
// runningJobs.value.splice(i, 1)
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// if (isNew) {
|
||||
// finishedJobs.value.unshift(data)
|
||||
// }
|
||||
// } else if (data.progress === -1) { // 任务执行失败
|
||||
// ElNotification({
|
||||
// title: '任务执行失败',
|
||||
// message: "提示词:" + data['prompt'],
|
||||
// type: 'error',
|
||||
// })
|
||||
// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
|
||||
//
|
||||
// } else {
|
||||
// for (let i = 0; i < runningJobs.value.length; i++) {
|
||||
// if (runningJobs.value[i].id === data.id) {
|
||||
// isNew = false
|
||||
// runningJobs.value[i] = data
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// if (isNew) {
|
||||
// runningJobs.value.push(data)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
//
|
||||
// _socket.addEventListener('close', () => {
|
||||
// ElMessage.error("Websocket 已经断开,正在重新连接服务器")
|
||||
// connect()
|
||||
// });
|
||||
// }
|
||||
|
||||
const translatePrompt = () => {
|
||||
loading.value = true
|
||||
|
Loading…
Reference in New Issue
Block a user