mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	refactor mj service, add mj service pool support
This commit is contained in:
		@@ -33,7 +33,7 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
				BasePath: "./static/upload",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		MjConfig:     types.MidJourneyConfig{Enabled: false},
 | 
			
		||||
		MjConfigs:    types.MidJourneyConfig{Enabled: false},
 | 
			
		||||
		SdConfig:     types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
 | 
			
		||||
		WeChatBot:    false,
 | 
			
		||||
		AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
 | 
			
		||||
 
 | 
			
		||||
@@ -18,7 +18,7 @@ type AppConfig struct {
 | 
			
		||||
	AesEncryptKey string
 | 
			
		||||
	SmsConfig     AliYunSmsConfig       // AliYun send message service config
 | 
			
		||||
	OSS           OSSConfig             // OSS config
 | 
			
		||||
	MjConfig      MidJourneyConfig      // mj 绘画配置
 | 
			
		||||
	MjConfigs     []MidJourneyConfig    // mj 绘画配置池子
 | 
			
		||||
	WeChatBot     bool                  // 是否启用微信机器人
 | 
			
		||||
	SdConfig      StableDiffusionConfig // sd 绘画配置
 | 
			
		||||
 | 
			
		||||
@@ -116,7 +116,7 @@ type ChatConfig struct {
 | 
			
		||||
	EnableHistory bool   `json:"enable_history"` // 是否允许保存聊天记录
 | 
			
		||||
	ContextDeep   int    `json:"context_deep"`   // 上下文深度
 | 
			
		||||
	DallApiURL    string `json:"dall_api_url"`   // dall-e3 绘图 API 地址
 | 
			
		||||
	DallImgNum int `json:"dall_img_num"` // dall-e3 出图数量
 | 
			
		||||
	DallImgNum    int    `json:"dall_img_num"`   // dall-e3 出图数量
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Platform string
 | 
			
		||||
 
 | 
			
		||||
@@ -11,28 +11,15 @@ const (
 | 
			
		||||
	TaskImage     = TaskType("image")
 | 
			
		||||
	TaskUpscale   = TaskType("upscale")
 | 
			
		||||
	TaskVariation = TaskType("variation")
 | 
			
		||||
	TaskTxt2Img   = TaskType("text2img")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TaskSrc 任务来源
 | 
			
		||||
type TaskSrc string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TaskSrcChat = TaskSrc("chat") // 来自聊天页面
 | 
			
		||||
	TaskSrcImg  = TaskSrc("img")  // 专业绘画页面
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MjTask MidJourney 任务
 | 
			
		||||
type MjTask struct {
 | 
			
		||||
	Id          int      `json:"id"`
 | 
			
		||||
	SessionId   string   `json:"session_id"`
 | 
			
		||||
	Src         TaskSrc  `json:"src"`
 | 
			
		||||
	Type        TaskType `json:"type"`
 | 
			
		||||
	UserId      int      `json:"user_id"`
 | 
			
		||||
	Prompt      string   `json:"prompt,omitempty"`
 | 
			
		||||
	ChatId      string   `json:"chat_id,omitempty"`
 | 
			
		||||
	RoleId      int      `json:"role_id,omitempty"`
 | 
			
		||||
	Icon        string   `json:"icon,omitempty"`
 | 
			
		||||
	Index       int      `json:"index,omitempty"`
 | 
			
		||||
	MessageId   string   `json:"message_id,omitempty"`
 | 
			
		||||
	MessageHash string   `json:"message_hash,omitempty"`
 | 
			
		||||
@@ -42,7 +29,6 @@ type MjTask struct {
 | 
			
		||||
type SdTask struct {
 | 
			
		||||
	Id         int          `json:"id"` // job 数据库ID
 | 
			
		||||
	SessionId  string       `json:"session_id"`
 | 
			
		||||
	Src        TaskSrc      `json:"src"`
 | 
			
		||||
	Type       TaskType     `json:"type"`
 | 
			
		||||
	UserId     int          `json:"user_id"`
 | 
			
		||||
	Prompt     string       `json:"prompt,omitempty"`
 | 
			
		||||
 
 | 
			
		||||
@@ -12,9 +12,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@@ -40,20 +38,6 @@ func NewMidJourneyHandler(
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.mjService.Clients.Put(sessionId, client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -72,7 +56,7 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
	if !h.App.Config.MjConfig.Enabled {
 | 
			
		||||
	if !h.App.Config.MjConfigs[0].Enabled {
 | 
			
		||||
		resp.ERROR(c, "MidJourney service is disabled")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										19
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								api/main.go
									
									
									
									
									
								
							@@ -165,23 +165,6 @@ func main() {
 | 
			
		||||
 | 
			
		||||
		// MidJourney 机器人
 | 
			
		||||
		fx.Provide(mj.NewBot),
 | 
			
		||||
		fx.Provide(mj.NewClient),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
 | 
			
		||||
			if config.MjConfig.Enabled {
 | 
			
		||||
				err := bot.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Fatal("MidJourney 服务启动失败:", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
 | 
			
		||||
			if config.MjConfig.Enabled {
 | 
			
		||||
				go func() {
 | 
			
		||||
					mjService.Run()
 | 
			
		||||
				}()
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// Stable Diffusion 机器人
 | 
			
		||||
		fx.Provide(sd.NewService),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
 | 
			
		||||
@@ -256,13 +239,11 @@ func main() {
 | 
			
		||||
			group.POST("upscale", h.Upscale)
 | 
			
		||||
			group.POST("variation", h.Variation)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/sd")
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 管理后台控制器
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ type Bot struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
 | 
			
		||||
	discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
 | 
			
		||||
	discord, err := discordgo.New("Bot " + config.MjConfigs.BotToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -41,7 +41,7 @@ func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		config:  &config.MjConfig,
 | 
			
		||||
		config:  &config.MjConfigs,
 | 
			
		||||
		bot:     discord,
 | 
			
		||||
		service: service,
 | 
			
		||||
	}, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -14,13 +15,13 @@ type Client struct {
 | 
			
		||||
	config *types.MidJourneyConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config *types.AppConfig) *Client {
 | 
			
		||||
func NewClient(config *types.MidJourneyConfig, proxy string) *Client {
 | 
			
		||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
			
		||||
	// set proxy URL
 | 
			
		||||
	if config.ProxyURL != "" {
 | 
			
		||||
		client.SetProxyURL(config.ProxyURL)
 | 
			
		||||
	if utils.IsEmptyValue(proxy) {
 | 
			
		||||
		client.SetProxyURL(proxy)
 | 
			
		||||
	}
 | 
			
		||||
	return &Client{client: client, config: &config.MjConfig}
 | 
			
		||||
	return &Client{client: client, config: config}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(prompt string) error {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										38
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ServicePool Mj service pool
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services  []Service
 | 
			
		||||
	taskQueue *store.RedisQueue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
 | 
			
		||||
	// create mj client and service
 | 
			
		||||
	for _, config := range appConfig.MjConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// create mj client
 | 
			
		||||
		client := NewClient(&config, appConfig.ProxyURL)
 | 
			
		||||
 | 
			
		||||
		// create mj service
 | 
			
		||||
		service := NewService()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
		taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	p.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
@@ -2,63 +2,63 @@ package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MJ 绘画服务
 | 
			
		||||
 | 
			
		||||
const RunningJobKey = "MidJourney_Running_Job"
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	client        *Client // MJ 客户端
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	redis         *redis.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	Clients       *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
 | 
			
		||||
	proxyURL      string
 | 
			
		||||
	name             string  // service name
 | 
			
		||||
	client           *Client // MJ client
 | 
			
		||||
	taskQueue        *store.RedisQueue
 | 
			
		||||
	db               *gorm.DB
 | 
			
		||||
	uploadManager    *oss.UploaderManager
 | 
			
		||||
	proxyURL         string
 | 
			
		||||
	maxHandleTaskNum int32             // max task number current service can handle
 | 
			
		||||
	handledTaskNum   int32             // already handled task number
 | 
			
		||||
	taskStartTimes   map[int]time.Time // task start time, to check if the task is timeout
 | 
			
		||||
	taskTimeout      int64
 | 
			
		||||
	snowflake        *service.Snowflake
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
 | 
			
		||||
func NewService(name string, queue *store.RedisQueue, timeout int64, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		redis:         redisCli,
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
 | 
			
		||||
		client:        client,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		Clients:       types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		proxyURL:      config.ProxyURL,
 | 
			
		||||
		name:           name,
 | 
			
		||||
		db:             db,
 | 
			
		||||
		taskQueue:      queue,
 | 
			
		||||
		client:         client,
 | 
			
		||||
		uploadManager:  manager,
 | 
			
		||||
		taskTimeout:    timeout,
 | 
			
		||||
		proxyURL:       config.ProxyURL,
 | 
			
		||||
		taskStartTimes: make(map[int]time.Time, 0),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Info("Starting MidJourney job consumer.")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.name)
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := s.redis.Get(ctx, RunningJobKey).Result()
 | 
			
		||||
		if err == nil { // 队列串行执行
 | 
			
		||||
		s.checkTasks()
 | 
			
		||||
		if !s.canHandleTask() {
 | 
			
		||||
			// current service is full, can not handle more task
 | 
			
		||||
			// waiting for running task finish
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err = s.taskQueue.LPop(&task)
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Consuming Task: %+v", task)
 | 
			
		||||
 | 
			
		||||
		logger.Infof("handle a new MidJourney task: %+v", task)
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			err = s.client.Imagine(task.Prompt)
 | 
			
		||||
@@ -70,50 +70,40 @@ func (s *Service) Run() {
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			// 删除任务
 | 
			
		||||
			s.db.Delete(&model.MidJourneyJob{Id: uint(task.Id)})
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := s.Clients.Get(task.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, vo.MidJourneyJob{
 | 
			
		||||
					Type:      task.Type.String(),
 | 
			
		||||
					UserId:    task.UserId,
 | 
			
		||||
					MessageId: task.MessageId,
 | 
			
		||||
					Progress:  -1,
 | 
			
		||||
					Prompt:    task.Prompt,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新任务的执行状态
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
 | 
			
		||||
		// 锁定任务执行通道,直到任务超时(5分钟)
 | 
			
		||||
		s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
		// lock the task until the execute timeout
 | 
			
		||||
		s.taskStartTimes[task.Id] = time.Now()
 | 
			
		||||
		atomic.AddInt32(&s.handledTaskNum, 1)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Infof("add a new MidJourney Task: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
// check if current service instance can handle more task
 | 
			
		||||
func (s *Service) canHandleTask() bool {
 | 
			
		||||
	handledNum := atomic.LoadInt32(&s.handledTaskNum)
 | 
			
		||||
	return handledNum < s.maxHandleTaskNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) checkTasks() {
 | 
			
		||||
	for k, t := range s.taskStartTimes {
 | 
			
		||||
		if time.Now().Unix()-t.Unix() > s.taskTimeout {
 | 
			
		||||
			delete(s.taskStartTimes, k)
 | 
			
		||||
			atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(data CBReq) {
 | 
			
		||||
	taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
 | 
			
		||||
	if err != nil { // 过期任务,丢弃
 | 
			
		||||
		logger.Warn("任务已过期:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var task types.MjTask
 | 
			
		||||
	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
		logger.Warn("任务解析失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// extract the task ID
 | 
			
		||||
	split := strings.Split(data.Prompt, " ")
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := s.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
			
		||||
	if res.Error == nil && data.Status == Finished {
 | 
			
		||||
@@ -121,137 +111,37 @@ func (s *Service) Notify(data CBReq) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if task.Src == types.TaskSrcImg { // 绘画任务
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		res := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Warn("非法任务:", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		job.MessageId = data.MessageId
 | 
			
		||||
		job.ReferenceId = data.ReferenceId
 | 
			
		||||
		job.Progress = data.Progress
 | 
			
		||||
		job.Prompt = data.Prompt
 | 
			
		||||
		job.Hash = data.Image.Hash
 | 
			
		||||
	res = s.db.Where("task_id = ?", split[0]).First(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Warn("非法任务:", res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	job.MessageId = data.MessageId
 | 
			
		||||
	job.ReferenceId = data.ReferenceId
 | 
			
		||||
	job.Progress = data.Progress
 | 
			
		||||
	job.Prompt = data.Prompt
 | 
			
		||||
	job.Hash = data.Image.Hash
 | 
			
		||||
	job.OrgURL = data.Image.URL // save origin image
 | 
			
		||||
 | 
			
		||||
		// 任务完成,将最终的图片下载下来
 | 
			
		||||
		if data.Progress == 100 {
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download img: ", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
		} else {
 | 
			
		||||
			// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
			job.ImgURL = data.Image.URL
 | 
			
		||||
		}
 | 
			
		||||
		res = s.db.Updates(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update job: ", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	// upload image
 | 
			
		||||
	imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download img: ", err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	job.ImgURL = imgURL
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			if data.Progress < 100 {
 | 
			
		||||
				image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := s.Clients.Get(task.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else if task.Src == types.TaskSrcChat { // 聊天任务
 | 
			
		||||
		wsClient := s.ChatClients.Get(task.SessionId)
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
				content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
			
		||||
				utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			}
 | 
			
		||||
			// download image
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
				if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
					content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
			
		||||
					utils.ReplyMessage(wsClient, content)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tx := s.db.Begin()
 | 
			
		||||
			data.Image.URL = imgURL
 | 
			
		||||
			message := model.HistoryMessage{
 | 
			
		||||
				UserId:     uint(task.UserId),
 | 
			
		||||
				ChatId:     task.ChatId,
 | 
			
		||||
				RoleId:     uint(task.RoleId),
 | 
			
		||||
				Type:       types.MjMsg,
 | 
			
		||||
				Icon:       task.Icon,
 | 
			
		||||
				Content:    utils.JsonEncode(data),
 | 
			
		||||
				Tokens:     0,
 | 
			
		||||
				UseContext: false,
 | 
			
		||||
			}
 | 
			
		||||
			res = tx.Create(&message)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// save the job
 | 
			
		||||
			job.UserId = task.UserId
 | 
			
		||||
			job.Type = task.Type.String()
 | 
			
		||||
			job.MessageId = data.MessageId
 | 
			
		||||
			job.ReferenceId = data.ReferenceId
 | 
			
		||||
			job.Prompt = data.Prompt
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
			job.Progress = data.Progress
 | 
			
		||||
			job.Hash = data.Image.Hash
 | 
			
		||||
			job.CreatedAt = time.Now()
 | 
			
		||||
			res = tx.Create(&job)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				tx.Rollback()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			tx.Commit()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if wsClient == nil { // 客户端断线,则丢弃
 | 
			
		||||
			logger.Errorf("Client is offline: %+v", data)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
			// 本次绘画完毕,移除客户端
 | 
			
		||||
			s.ChatClients.Delete(task.SessionId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// 使用代理临时转发图片
 | 
			
		||||
			if data.Image.URL != "" {
 | 
			
		||||
				image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
		}
 | 
			
		||||
	res = s.db.Updates(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update job: ", res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户剩余绘图次数
 | 
			
		||||
	// TODO: 放大图片是否需要消耗绘图次数?
 | 
			
		||||
	if data.Status == Finished {
 | 
			
		||||
		s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
		// 解除任务锁定
 | 
			
		||||
		s.redis.Del(context.Background(), RunningJobKey)
 | 
			
		||||
		// update user's img calls
 | 
			
		||||
		s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
		// release lock task
 | 
			
		||||
		atomic.AddInt32(&s.handledTaskNum, -1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -55,7 +55,7 @@ func (s *HuPiPayService) Sign(params map[string]string) string {
 | 
			
		||||
	var data string
 | 
			
		||||
	keys := make([]string, 0, 0)
 | 
			
		||||
	params["appid"] = s.appId
 | 
			
		||||
	for key, _ := range params {
 | 
			
		||||
	for key := range params {
 | 
			
		||||
		keys = append(keys, key)
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(keys)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,13 +6,14 @@ type MidJourneyJob struct {
 | 
			
		||||
	Id          uint `gorm:"primarykey;column:id"`
 | 
			
		||||
	Type        string
 | 
			
		||||
	UserId      int
 | 
			
		||||
	TaskId      string
 | 
			
		||||
	MessageId   string
 | 
			
		||||
	ReferenceId string
 | 
			
		||||
	ImgURL      string
 | 
			
		||||
	OrgURL      string // 原图地址
 | 
			
		||||
	Hash        string // message hash
 | 
			
		||||
	Progress    int
 | 
			
		||||
	Prompt      string
 | 
			
		||||
	Started     bool
 | 
			
		||||
	CreatedAt   time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -9,9 +9,9 @@ type MidJourneyJob struct {
 | 
			
		||||
	MessageId   string    `json:"message_id"`
 | 
			
		||||
	ReferenceId string    `json:"reference_id"`
 | 
			
		||||
	ImgURL      string    `json:"img_url"`
 | 
			
		||||
	OrgURL      string    `json:"org_url"`
 | 
			
		||||
	Hash        string    `json:"hash"`
 | 
			
		||||
	Progress    int       `json:"progress"`
 | 
			
		||||
	Prompt      string    `json:"prompt"`
 | 
			
		||||
	CreatedAt   time.Time `json:"created_at"`
 | 
			
		||||
	Started     bool      `json:"started"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -504,72 +504,72 @@ const socket = ref(null)
 | 
			
		||||
const imgCalls = ref(0)
 | 
			
		||||
const loading = ref(false)
 | 
			
		||||
 | 
			
		||||
const connect = () => {
 | 
			
		||||
  let host = process.env.VUE_APP_WS_HOST
 | 
			
		||||
  if (host === '') {
 | 
			
		||||
    if (location.protocol === 'https:') {
 | 
			
		||||
      host = 'wss://' + location.host;
 | 
			
		||||
    } else {
 | 
			
		||||
      host = 'ws://' + location.host;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
 | 
			
		||||
  _socket.addEventListener('open', () => {
 | 
			
		||||
    socket.value = _socket;
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  _socket.addEventListener('message', event => {
 | 
			
		||||
    if (event.data instanceof Blob) {
 | 
			
		||||
      const reader = new FileReader();
 | 
			
		||||
      reader.readAsText(event.data, "UTF-8");
 | 
			
		||||
      reader.onload = () => {
 | 
			
		||||
        const data = JSON.parse(String(reader.result));
 | 
			
		||||
        let isNew = true
 | 
			
		||||
        if (data.progress === 100) {
 | 
			
		||||
          for (let i = 0; i < finishedJobs.value.length; i++) {
 | 
			
		||||
            if (finishedJobs.value[i].id === data.id) {
 | 
			
		||||
              isNew = false
 | 
			
		||||
              break
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          for (let i = 0; i < runningJobs.value.length; i++) {
 | 
			
		||||
            if (runningJobs.value[i].id === data.id) {
 | 
			
		||||
              runningJobs.value.splice(i, 1)
 | 
			
		||||
              break
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          if (isNew) {
 | 
			
		||||
            finishedJobs.value.unshift(data)
 | 
			
		||||
          }
 | 
			
		||||
        } else if (data.progress === -1) { // 任务执行失败
 | 
			
		||||
          ElNotification({
 | 
			
		||||
            title: '任务执行失败',
 | 
			
		||||
            message: "提示词:" + data['prompt'],
 | 
			
		||||
            type: 'error',
 | 
			
		||||
          })
 | 
			
		||||
          runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
 | 
			
		||||
 | 
			
		||||
        } else {
 | 
			
		||||
          for (let i = 0; i < runningJobs.value.length; i++) {
 | 
			
		||||
            if (runningJobs.value[i].id === data.id) {
 | 
			
		||||
              isNew = false
 | 
			
		||||
              runningJobs.value[i] = data
 | 
			
		||||
              break
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          if (isNew) {
 | 
			
		||||
            runningJobs.value.push(data)
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  _socket.addEventListener('close', () => {
 | 
			
		||||
    ElMessage.error("Websocket 已经断开,正在重新连接服务器")
 | 
			
		||||
    connect()
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
// const connect = () => {
 | 
			
		||||
//   let host = process.env.VUE_APP_WS_HOST
 | 
			
		||||
//   if (host === '') {
 | 
			
		||||
//     if (location.protocol === 'https:') {
 | 
			
		||||
//       host = 'wss://' + location.host;
 | 
			
		||||
//     } else {
 | 
			
		||||
//       host = 'ws://' + location.host;
 | 
			
		||||
//     }
 | 
			
		||||
//   }
 | 
			
		||||
//   const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
 | 
			
		||||
//   _socket.addEventListener('open', () => {
 | 
			
		||||
//     socket.value = _socket;
 | 
			
		||||
//   });
 | 
			
		||||
//
 | 
			
		||||
//   _socket.addEventListener('message', event => {
 | 
			
		||||
//     if (event.data instanceof Blob) {
 | 
			
		||||
//       const reader = new FileReader();
 | 
			
		||||
//       reader.readAsText(event.data, "UTF-8");
 | 
			
		||||
//       reader.onload = () => {
 | 
			
		||||
//         const data = JSON.parse(String(reader.result));
 | 
			
		||||
//         let isNew = true
 | 
			
		||||
//         if (data.progress === 100) {
 | 
			
		||||
//           for (let i = 0; i < finishedJobs.value.length; i++) {
 | 
			
		||||
//             if (finishedJobs.value[i].id === data.id) {
 | 
			
		||||
//               isNew = false
 | 
			
		||||
//               break
 | 
			
		||||
//             }
 | 
			
		||||
//           }
 | 
			
		||||
//           for (let i = 0; i < runningJobs.value.length; i++) {
 | 
			
		||||
//             if (runningJobs.value[i].id === data.id) {
 | 
			
		||||
//               runningJobs.value.splice(i, 1)
 | 
			
		||||
//               break
 | 
			
		||||
//             }
 | 
			
		||||
//           }
 | 
			
		||||
//           if (isNew) {
 | 
			
		||||
//             finishedJobs.value.unshift(data)
 | 
			
		||||
//           }
 | 
			
		||||
//         } else if (data.progress === -1) { // 任务执行失败
 | 
			
		||||
//           ElNotification({
 | 
			
		||||
//             title: '任务执行失败',
 | 
			
		||||
//             message: "提示词:" + data['prompt'],
 | 
			
		||||
//             type: 'error',
 | 
			
		||||
//           })
 | 
			
		||||
//           runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id)
 | 
			
		||||
//
 | 
			
		||||
//         } else {
 | 
			
		||||
//           for (let i = 0; i < runningJobs.value.length; i++) {
 | 
			
		||||
//             if (runningJobs.value[i].id === data.id) {
 | 
			
		||||
//               isNew = false
 | 
			
		||||
//               runningJobs.value[i] = data
 | 
			
		||||
//               break
 | 
			
		||||
//             }
 | 
			
		||||
//           }
 | 
			
		||||
//           if (isNew) {
 | 
			
		||||
//             runningJobs.value.push(data)
 | 
			
		||||
//           }
 | 
			
		||||
//         }
 | 
			
		||||
//       }
 | 
			
		||||
//     }
 | 
			
		||||
//   });
 | 
			
		||||
//
 | 
			
		||||
//   _socket.addEventListener('close', () => {
 | 
			
		||||
//     ElMessage.error("Websocket 已经断开,正在重新连接服务器")
 | 
			
		||||
//     connect()
 | 
			
		||||
//   });
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
const translatePrompt = () => {
 | 
			
		||||
  loading.value = true
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user