mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
refactor mj service, add mj service pool support
This commit is contained in:
parent
c012f0c4c5
commit
cf758d773e
@ -33,7 +33,7 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
BasePath: "./static/upload",
|
BasePath: "./static/upload",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
MjConfig: types.MidJourneyConfig{Enabled: false},
|
MjConfigs: types.MidJourneyConfig{Enabled: false},
|
||||||
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
|
SdConfig: types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
|
||||||
WeChatBot: false,
|
WeChatBot: false,
|
||||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||||
|
@ -18,7 +18,7 @@ type AppConfig struct {
|
|||||||
AesEncryptKey string
|
AesEncryptKey string
|
||||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
MjConfig MidJourneyConfig // mj 绘画配置
|
MjConfigs []MidJourneyConfig // mj 绘画配置池子
|
||||||
WeChatBot bool // 是否启用微信机器人
|
WeChatBot bool // 是否启用微信机器人
|
||||||
SdConfig StableDiffusionConfig // sd 绘画配置
|
SdConfig StableDiffusionConfig // sd 绘画配置
|
||||||
|
|
||||||
|
@ -11,28 +11,15 @@ const (
|
|||||||
TaskImage = TaskType("image")
|
TaskImage = TaskType("image")
|
||||||
TaskUpscale = TaskType("upscale")
|
TaskUpscale = TaskType("upscale")
|
||||||
TaskVariation = TaskType("variation")
|
TaskVariation = TaskType("variation")
|
||||||
TaskTxt2Img = TaskType("text2img")
|
|
||||||
)
|
|
||||||
|
|
||||||
// TaskSrc 任务来源
|
|
||||||
type TaskSrc string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
|
|
||||||
TaskSrcImg = TaskSrc("img") // 专业绘画页面
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MjTask MidJourney 任务
|
// MjTask MidJourney 任务
|
||||||
type MjTask struct {
|
type MjTask struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Src TaskSrc `json:"src"`
|
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
ChatId string `json:"chat_id,omitempty"`
|
|
||||||
RoleId int `json:"role_id,omitempty"`
|
|
||||||
Icon string `json:"icon,omitempty"`
|
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
MessageId string `json:"message_id,omitempty"`
|
MessageId string `json:"message_id,omitempty"`
|
||||||
MessageHash string `json:"message_hash,omitempty"`
|
MessageHash string `json:"message_hash,omitempty"`
|
||||||
@ -42,7 +29,6 @@ type MjTask struct {
|
|||||||
type SdTask struct {
|
type SdTask struct {
|
||||||
Id int `json:"id"` // job 数据库ID
|
Id int `json:"id"` // job 数据库ID
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Src TaskSrc `json:"src"`
|
|
||||||
Type TaskType `json:"type"`
|
Type TaskType `json:"type"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
@ -12,9 +12,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -40,20 +38,6 @@ func NewMidJourneyHandler(
|
|||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
|
||||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionId := c.Query("session_id")
|
|
||||||
client := types.NewWsClient(ws)
|
|
||||||
h.mjService.Clients.Put(sessionId, client)
|
|
||||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := utils.GetLoginUser(c, h.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -72,7 +56,7 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
|||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||||
if !h.App.Config.MjConfig.Enabled {
|
if !h.App.Config.MjConfigs[0].Enabled {
|
||||||
resp.ERROR(c, "MidJourney service is disabled")
|
resp.ERROR(c, "MidJourney service is disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
19
api/main.go
19
api/main.go
@ -165,23 +165,6 @@ func main() {
|
|||||||
|
|
||||||
// MidJourney 机器人
|
// MidJourney 机器人
|
||||||
fx.Provide(mj.NewBot),
|
fx.Provide(mj.NewBot),
|
||||||
fx.Provide(mj.NewClient),
|
|
||||||
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
|
|
||||||
if config.MjConfig.Enabled {
|
|
||||||
err := bot.Run()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("MidJourney 服务启动失败:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
|
|
||||||
if config.MjConfig.Enabled {
|
|
||||||
go func() {
|
|
||||||
mjService.Run()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
|
|
||||||
// Stable Diffusion 机器人
|
// Stable Diffusion 机器人
|
||||||
fx.Provide(sd.NewService),
|
fx.Provide(sd.NewService),
|
||||||
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
|
fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
|
||||||
@ -256,13 +239,11 @@ func main() {
|
|||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
group.Any("client", h.Client)
|
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||||
group := s.Engine.Group("/api/sd")
|
group := s.Engine.Group("/api/sd")
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
group.Any("client", h.Client)
|
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
|
@ -23,7 +23,7 @@ type Bot struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||||
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
|
discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Bot{
|
return &Bot{
|
||||||
config: &config.MjConfig,
|
config: &config.MjConfigs,
|
||||||
bot: discord,
|
bot: discord,
|
||||||
service: service,
|
service: service,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -2,6 +2,7 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
"time"
|
"time"
|
||||||
@ -14,13 +15,13 @@ type Client struct {
|
|||||||
config *types.MidJourneyConfig
|
config *types.MidJourneyConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(config *types.AppConfig) *Client {
|
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
|
||||||
client := req.C().SetTimeout(10 * time.Second)
|
client := req.C().SetTimeout(10 * time.Second)
|
||||||
// set proxy URL
|
// set proxy URL
|
||||||
if config.ProxyURL != "" {
|
if utils.IsEmptyValue(proxy) {
|
||||||
client.SetProxyURL(config.ProxyURL)
|
client.SetProxyURL(proxy)
|
||||||
}
|
}
|
||||||
return &Client{client: client, config: &config.MjConfig}
|
return &Client{client: client, config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Imagine(prompt string) error {
|
func (c *Client) Imagine(prompt string) error {
|
||||||
|
38
api/service/mj/pool.go
Normal file
38
api/service/mj/pool.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/oss"
|
||||||
|
"chatplus/store"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServicePool Mj service pool
|
||||||
|
type ServicePool struct {
|
||||||
|
services []Service
|
||||||
|
taskQueue *store.RedisQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
|
||||||
|
// create mj client and service
|
||||||
|
for _, config := range appConfig.MjConfigs {
|
||||||
|
if config.Enabled == false {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// create mj client
|
||||||
|
client := NewClient(&config, appConfig.ProxyURL)
|
||||||
|
|
||||||
|
// create mj service
|
||||||
|
service := NewService()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ServicePool{
|
||||||
|
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ServicePool) PushTask(task types.MjTask) {
|
||||||
|
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||||
|
p.taskQueue.RPush(task)
|
||||||
|
}
|
@ -2,63 +2,63 @@ package mj
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MJ 绘画服务
|
// Service MJ 绘画服务
|
||||||
|
|
||||||
const RunningJobKey = "MidJourney_Running_Job"
|
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
client *Client // MJ 客户端
|
name string // service name
|
||||||
|
client *Client // MJ client
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
redis *redis.Client
|
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
uploadManager *oss.UploaderManager
|
uploadManager *oss.UploaderManager
|
||||||
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
|
|
||||||
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
|
|
||||||
proxyURL string
|
proxyURL string
|
||||||
|
maxHandleTaskNum int32 // max task number current service can handle
|
||||||
|
handledTaskNum int32 // already handled task number
|
||||||
|
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
||||||
|
taskTimeout int64
|
||||||
|
snowflake *service.Snowflake
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
redis: redisCli,
|
name: name,
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
taskQueue: queue,
|
||||||
client: client,
|
client: client,
|
||||||
uploadManager: manager,
|
uploadManager: manager,
|
||||||
Clients: types.NewLMap[string, *types.WsClient](),
|
taskTimeout: timeout,
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
|
||||||
proxyURL: config.ProxyURL,
|
proxyURL: config.ProxyURL,
|
||||||
|
taskStartTimes: make(map[int]time.Time, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Info("Starting MidJourney job consumer.")
|
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
||||||
ctx := context.Background()
|
|
||||||
for {
|
for {
|
||||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
s.checkTasks()
|
||||||
if err == nil { // 队列串行执行
|
if !s.canHandleTask() {
|
||||||
|
// current service is full, can not handle more task
|
||||||
|
// waiting for running task finish
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var task types.MjTask
|
var task types.MjTask
|
||||||
err = s.taskQueue.LPop(&task)
|
err := s.taskQueue.LPop(&task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("taking task with error: %v", err)
|
logger.Errorf("taking task with error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Infof("Consuming Task: %+v", task)
|
|
||||||
|
logger.Infof("handle a new MidJourney task: %+v", task)
|
||||||
switch task.Type {
|
switch task.Type {
|
||||||
case types.TaskImage:
|
case types.TaskImage:
|
||||||
err = s.client.Imagine(task.Prompt)
|
err = s.client.Imagine(task.Prompt)
|
||||||
@ -70,50 +70,40 @@ func (s *Service) Run() {
|
|||||||
case types.TaskVariation:
|
case types.TaskVariation:
|
||||||
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("绘画任务执行失败:", err)
|
logger.Error("绘画任务执行失败:", err)
|
||||||
// 删除任务
|
// update the task progress
|
||||||
s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)})
|
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
||||||
// 推送任务到前端
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||||
client := s.Clients.Get(task.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, vo.MidJourneyJob{
|
|
||||||
Type: task.Type.String(),
|
|
||||||
UserId: task.UserId,
|
|
||||||
MessageId: task.MessageId,
|
|
||||||
Progress: -1,
|
|
||||||
Prompt: task.Prompt,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新任务的执行状态
|
// lock the task until the execute timeout
|
||||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
s.taskStartTimes[task.Id] = time.Now()
|
||||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
atomic.AddInt32(&s.handledTaskNum, 1)
|
||||||
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) PushTask(task types.MjTask) {
|
// check if current service instance can handle more task
|
||||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
func (s *Service) canHandleTask() bool {
|
||||||
s.taskQueue.RPush(task)
|
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
||||||
|
return handledNum < s.maxHandleTaskNum
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) checkTasks() {
|
||||||
|
for k, t := range s.taskStartTimes {
|
||||||
|
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
||||||
|
delete(s.taskStartTimes, k)
|
||||||
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Notify(data CBReq) {
|
func (s *Service) Notify(data CBReq) {
|
||||||
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
|
// extract the task ID
|
||||||
if err != nil { // 过期任务,丢弃
|
split := strings.Split(data.Prompt, " ")
|
||||||
logger.Warn("任务已过期:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
|
||||||
err = utils.JsonDecode(taskString, &task)
|
|
||||||
if err != nil { // 非标准任务,丢弃
|
|
||||||
logger.Warn("任务解析失败:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var job model.MidJourneyJob
|
var job model.MidJourneyJob
|
||||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||||
if res.Error == nil && data.Status == Finished {
|
if res.Error == nil && data.Status == Finished {
|
||||||
@ -121,9 +111,7 @@ func (s *Service) Notify(data CBReq) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
res = s.db.Where("task_id = ?", split[0]).First(&job)
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := s.db.Where("id = ?", task.Id).First(&job)
|
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Warn("非法任务:", res.Error)
|
logger.Warn("非法任务:", res.Error)
|
||||||
return
|
return
|
||||||
@ -133,125 +121,27 @@ func (s *Service) Notify(data CBReq) {
|
|||||||
job.Progress = data.Progress
|
job.Progress = data.Progress
|
||||||
job.Prompt = data.Prompt
|
job.Prompt = data.Prompt
|
||||||
job.Hash = data.Image.Hash
|
job.Hash = data.Image.Hash
|
||||||
|
job.OrgURL = data.Image.URL // save origin image
|
||||||
|
|
||||||
// 任务完成,将最终的图片下载下来
|
// upload image
|
||||||
if data.Progress == 100 {
|
|
||||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with download img: ", err.Error())
|
logger.Error("error with download img: ", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
job.ImgURL = imgURL
|
job.ImgURL = imgURL
|
||||||
} else {
|
|
||||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
|
||||||
job.ImgURL = data.Image.URL
|
|
||||||
}
|
|
||||||
res = s.db.Updates(&job)
|
res = s.db.Updates(&job)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("error with update job: ", res.Error)
|
logger.Error("error with update job: ", res.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobVo vo.MidJourneyJob
|
|
||||||
err := utils.CopyObject(job, &jobVo)
|
|
||||||
if err == nil {
|
|
||||||
if data.Progress < 100 {
|
|
||||||
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
|
|
||||||
if err == nil {
|
|
||||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 推送任务到前端
|
|
||||||
client := s.Clients.Get(task.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
|
||||||
wsClient := s.ChatClients.Get(task.SessionId)
|
|
||||||
if data.Status == Finished {
|
if data.Status == Finished {
|
||||||
if wsClient != nil && data.ReferenceId != "" {
|
// update user's img calls
|
||||||
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
utils.ReplyMessage(wsClient, content)
|
// release lock task
|
||||||
}
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
||||||
// download image
|
|
||||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
if wsClient != nil && data.ReferenceId != "" {
|
|
||||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
|
||||||
utils.ReplyMessage(wsClient, content)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tx := s.db.Begin()
|
|
||||||
data.Image.URL = imgURL
|
|
||||||
message := model.HistoryMessage{
|
|
||||||
UserId: uint(task.UserId),
|
|
||||||
ChatId: task.ChatId,
|
|
||||||
RoleId: uint(task.RoleId),
|
|
||||||
Type: types.MjMsg,
|
|
||||||
Icon: task.Icon,
|
|
||||||
Content: utils.JsonEncode(data),
|
|
||||||
Tokens: 0,
|
|
||||||
UseContext: false,
|
|
||||||
}
|
|
||||||
res = tx.Create(&message)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update database: ", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// save the job
|
|
||||||
job.UserId = task.UserId
|
|
||||||
job.Type = task.Type.String()
|
|
||||||
job.MessageId = data.MessageId
|
|
||||||
job.ReferenceId = data.ReferenceId
|
|
||||||
job.Prompt = data.Prompt
|
|
||||||
job.ImgURL = imgURL
|
|
||||||
job.Progress = data.Progress
|
|
||||||
job.Hash = data.Image.Hash
|
|
||||||
job.CreatedAt = time.Now()
|
|
||||||
res = tx.Create(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update database: ", err)
|
|
||||||
tx.Rollback()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
if wsClient == nil { // 客户端断线,则丢弃
|
|
||||||
logger.Errorf("Client is offline: %+v", data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.Status == Finished {
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
|
||||||
// 本次绘画完毕,移除客户端
|
|
||||||
s.ChatClients.Delete(task.SessionId)
|
|
||||||
} else {
|
|
||||||
// 使用代理临时转发图片
|
|
||||||
if data.Image.URL != "" {
|
|
||||||
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
|
|
||||||
if err == nil {
|
|
||||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户剩余绘图次数
|
|
||||||
// TODO: 放大图片是否需要消耗绘图次数?
|
|
||||||
if data.Status == Finished {
|
|
||||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
|
||||||
// 解除任务锁定
|
|
||||||
s.redis.Del(context.Background(), RunningJobKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
|
|||||||
var data string
|
var data string
|
||||||
keys := make([]string, 0, 0)
|
keys := make([]string, 0, 0)
|
||||||
params["appid"] = s.appId
|
params["appid"] = s.appId
|
||||||
for key, _ := range params {
|
for key := range params {
|
||||||
keys = append(keys, key)
|
keys = append(keys, key)
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
@ -6,13 +6,14 @@ type MidJourneyJob struct {
|
|||||||
Id uint `gorm:"primarykey;column:id"`
|
Id uint `gorm:"primarykey;column:id"`
|
||||||
Type string
|
Type string
|
||||||
UserId int
|
UserId int
|
||||||
|
TaskId string
|
||||||
MessageId string
|
MessageId string
|
||||||
ReferenceId string
|
ReferenceId string
|
||||||
ImgURL string
|
ImgURL string
|
||||||
|
OrgURL string // 原图地址
|
||||||
Hash string // message hash
|
Hash string // message hash
|
||||||
Progress int
|
Progress int
|
||||||
Prompt string
|
Prompt string
|
||||||
Started bool
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ type MidJourneyJob struct {
|
|||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
ReferenceId string `json:"reference_id"`
|
ReferenceId string `json:"reference_id"`
|
||||||
ImgURL string `json:"img_url"`
|
ImgURL string `json:"img_url"`
|
||||||
|
OrgURL string `json:"org_url"`
|
||||||
Hash string `json:"hash"`
|
Hash string `json:"hash"`
|
||||||
Progress int `json:"progress"`
|
Progress int `json:"progress"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Started bool `json:"started"`
|
|
||||||
}
|
}
|
||||||
|
@ -504,72 +504,72 @@ const socket = ref(null)
|
|||||||
const imgCalls = ref(0)
|
const imgCalls = ref(0)
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
|
|
||||||
const connect = () => {
|
// const connect = () => {
|
||||||
let host = process.env.VUE_APP_WS_HOST
|
// let host = process.env.VUE_APP_WS_HOST
|
||||||
if (host === '') {
|
// if (host === '') {
|
||||||
if (location.protocol === 'https:') {
|
// if (location.protocol === 'https:') {
|
||||||
host = 'wss://' + location.host;
|
// host = 'wss://' + location.host;
|
||||||
} else {
|
// } else {
|
||||||
host = 'ws://' + location.host;
|
// host = 'ws://' + location.host;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
// const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
||||||
_socket.addEventListener('open', () => {
|
// _socket.addEventListener('open', () => {
|
||||||
socket.value = _socket;
|
// socket.value = _socket;
|
||||||
});
|
// });
|
||||||
|
//
|
||||||
_socket.addEventListener('message', event => {
|
// _socket.addEventListener('message', event => {
|
||||||
if (event.data instanceof Blob) {
|
// if (event.data instanceof Blob) {
|
||||||
const reader = new FileReader();
|
// const reader = new FileReader();
|
||||||
reader.readAsText(event.data, "UTF-8");
|
// reader.readAsText(event.data, "UTF-8");
|
||||||
reader.onload = () => {
|
// reader.onload = () => {
|
||||||
const data = JSON.parse(String(reader.result));
|
// const data = JSON.parse(String(reader.result));
|
||||||
let isNew = true
|
// let isNew = true
|
||||||
if (data.progress === 100) {
|
// if (data.progress === 100) {
|
||||||
for (let i = 0; i < finishedJobs.value.length; i++) {
|
// for (let i = 0; i < finishedJobs.value.length; i++) {
|
||||||
if (finishedJobs.value[i].id === data.id) {
|
// if (finishedJobs.value[i].id === data.id) {
|
||||||
isNew = false
|
// isNew = false
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
for (let i = 0; i < runningJobs.value.length; i++) {
|
// for (let i = 0; i < runningJobs.value.length; i++) {
|
||||||
if (runningJobs.value[i].id === data.id) {
|
// if (runningJobs.value[i].id === data.id) {
|
||||||
runningJobs.value.splice(i, 1)
|
// runningJobs.value.splice(i, 1)
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
if (isNew) {
|
// if (isNew) {
|
||||||
finishedJobs.value.unshift(data)
|
// finishedJobs.value.unshift(data)
|
||||||
}
|
// }
|
||||||
} else if (data.progress === -1) { // 任务执行失败
|
// } else if (data.progress === -1) { // 任务执行失败
|
||||||
ElNotification({
|
// ElNotification({
|
||||||
title: '任务执行失败',
|
// title: '任务执行失败',
|
||||||
message: "提示词:" + data['prompt'],
|
// message: "提示词:" + data['prompt'],
|
||||||
type: 'error',
|
// type: 'error',
|
||||||
})
|
// })
|
||||||
runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
|
// runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
|
||||||
|
//
|
||||||
} else {
|
// } else {
|
||||||
for (let i = 0; i < runningJobs.value.length; i++) {
|
// for (let i = 0; i < runningJobs.value.length; i++) {
|
||||||
if (runningJobs.value[i].id === data.id) {
|
// if (runningJobs.value[i].id === data.id) {
|
||||||
isNew = false
|
// isNew = false
|
||||||
runningJobs.value[i] = data
|
// runningJobs.value[i] = data
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
if (isNew) {
|
// if (isNew) {
|
||||||
runningJobs.value.push(data)
|
// runningJobs.value.push(data)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
|
//
|
||||||
_socket.addEventListener('close', () => {
|
// _socket.addEventListener('close', () => {
|
||||||
ElMessage.error("Websocket 已经断开,正在重新连接服务器")
|
// ElMessage.error("Websocket 已经断开,正在重新连接服务器")
|
||||||
connect()
|
// connect()
|
||||||
});
|
// });
|
||||||
}
|
// }
|
||||||
|
|
||||||
const translatePrompt = () => {
|
const translatePrompt = () => {
|
||||||
loading.value = true
|
loading.value = true
|
||||||
|
Loading…
Reference in New Issue
Block a user