mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: add midjourney message receive handler
This commit is contained in:
		@@ -34,7 +34,7 @@ type AppServer struct {
 | 
				
			|||||||
	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
						ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
				
			||||||
	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
						ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
				
			||||||
	Functions     map[string]function.Function
 | 
						Functions     map[string]function.Function
 | 
				
			||||||
	MjTasks       *types.LMap[string, types.MjTask]
 | 
						MjTaskClients *types.LMap[string, *types.WsClient]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
 | 
					func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
 | 
				
			||||||
@@ -48,7 +48,7 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio
 | 
				
			|||||||
		ChatSession:   types.NewLMap[string, types.ChatSession](),
 | 
							ChatSession:   types.NewLMap[string, types.ChatSession](),
 | 
				
			||||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
							ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
				
			||||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
							ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
				
			||||||
		MjTasks:       types.NewLMap[string, types.MjTask](),
 | 
							MjTaskClients: types.NewLMap[string, *types.WsClient](),
 | 
				
			||||||
		Functions:     functions,
 | 
							Functions:     functions,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,8 +12,8 @@ type BizVo struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// WsMessage Websocket message
 | 
					// WsMessage Websocket message
 | 
				
			||||||
type WsMessage struct {
 | 
					type WsMessage struct {
 | 
				
			||||||
	Type    WsMsgType `json:"type"` // 消息类别,start, end
 | 
						Type    WsMsgType   `json:"type"` // 消息类别,start, end, img
 | 
				
			||||||
	Content string    `json:"content"`
 | 
						Content interface{} `json:"content"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
type WsMsgType string
 | 
					type WsMsgType string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -21,6 +21,7 @@ const (
 | 
				
			|||||||
	WsStart  = WsMsgType("start")
 | 
						WsStart  = WsMsgType("start")
 | 
				
			||||||
	WsMiddle = WsMsgType("middle")
 | 
						WsMiddle = WsMsgType("middle")
 | 
				
			||||||
	WsEnd    = WsMsgType("end")
 | 
						WsEnd    = WsMsgType("end")
 | 
				
			||||||
 | 
						WsImg    = WsMsgType("img")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type BizCode int
 | 
					type BizCode int
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,6 +5,7 @@ import (
 | 
				
			|||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"chatplus/core"
 | 
						"chatplus/core"
 | 
				
			||||||
	"chatplus/core/types"
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						"chatplus/store"
 | 
				
			||||||
	"chatplus/store/model"
 | 
						"chatplus/store/model"
 | 
				
			||||||
	"chatplus/store/vo"
 | 
						"chatplus/store/vo"
 | 
				
			||||||
	"chatplus/utils"
 | 
						"chatplus/utils"
 | 
				
			||||||
@@ -26,14 +27,16 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
 | 
					const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
 | 
				
			||||||
 | 
					const TaskStorePrefix = "/tasks/"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ChatHandler struct {
 | 
					type ChatHandler struct {
 | 
				
			||||||
	BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	db *gorm.DB
 | 
						db      *gorm.DB
 | 
				
			||||||
 | 
						leveldb *store.LevelDB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
 | 
					func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ChatHandler {
 | 
				
			||||||
	handler := ChatHandler{db: db}
 | 
						handler := ChatHandler{db: db, leveldb: levelDB}
 | 
				
			||||||
	handler.App = app
 | 
						handler.App = app
 | 
				
			||||||
	return &handler
 | 
						return &handler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -133,7 +136,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
 | 
					// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
 | 
				
			||||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
 | 
					func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
 | 
				
			||||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
						promptCreatedAt := time.Now() // 记录提问时间
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var user model.User
 | 
						var user model.User
 | 
				
			||||||
@@ -340,13 +343,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 | 
				
			|||||||
					if functionName == types.FuncMidJourney {
 | 
										if functionName == types.FuncMidJourney {
 | 
				
			||||||
						key := utils.Sha256(data)
 | 
											key := utils.Sha256(data)
 | 
				
			||||||
						// add task for MidJourney
 | 
											// add task for MidJourney
 | 
				
			||||||
						h.App.MjTasks.Put(key, types.MjTask{
 | 
											h.App.MjTaskClients.Put(key, ws)
 | 
				
			||||||
 | 
											task := types.MjTask{
 | 
				
			||||||
							UserId: userVo.Id,
 | 
												UserId: userVo.Id,
 | 
				
			||||||
							RoleId: role.Id,
 | 
												RoleId: role.Id,
 | 
				
			||||||
							Icon:   role.Icon,
 | 
												Icon:   role.Icon,
 | 
				
			||||||
							Client: ws,
 | 
												Client: ws,
 | 
				
			||||||
							ChatId: session.ChatId,
 | 
												ChatId: session.ChatId,
 | 
				
			||||||
						})
 | 
											}
 | 
				
			||||||
 | 
											err := h.leveldb.Put(TaskStorePrefix+key, task)
 | 
				
			||||||
 | 
											if err != nil {
 | 
				
			||||||
 | 
												logger.Error("error with store MidJourney task: ", err)
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
											content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,6 +24,7 @@ type Image struct {
 | 
				
			|||||||
	Width    int    `json:"width"`
 | 
						Width    int    `json:"width"`
 | 
				
			||||||
	Height   int    `json:"height"`
 | 
						Height   int    `json:"height"`
 | 
				
			||||||
	Size     int    `json:"size"`
 | 
						Size     int    `json:"size"`
 | 
				
			||||||
 | 
						Hash     string `json:"hash"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MidJourneyHandler struct {
 | 
					type MidJourneyHandler struct {
 | 
				
			||||||
@@ -44,18 +45,30 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
		Image   Image      `json:"image"`
 | 
							Type      string     `json:"type"`
 | 
				
			||||||
		Content string     `json:"content"`
 | 
							MessageId string     `json:"message_id"`
 | 
				
			||||||
		Status  TaskStatus `json:"status"`
 | 
							Image     Image      `json:"image"`
 | 
				
			||||||
 | 
							Content   string     `json:"content"`
 | 
				
			||||||
 | 
							Prompt    string     `json:"prompt"`
 | 
				
			||||||
 | 
							Status    TaskStatus `json:"status"`
 | 
				
			||||||
 | 
							Key       string     `json:"key"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
						if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
				
			||||||
		resp.ERROR(c, types.InvalidArgs)
 | 
							resp.ERROR(c, types.InvalidArgs)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						key := utils.Sha256(data.Prompt)
 | 
				
			||||||
 | 
						data.Key = key
 | 
				
			||||||
 | 
						// TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
 | 
						wsClient := h.App.MjTaskClients.Get(key)
 | 
				
			||||||
	wsClient := h.App.ChatClients.Get(sessionId)
 | 
						if wsClient == nil { // 客户端断线,则丢弃
 | 
				
			||||||
	utils.ReplyMessage(wsClient, "")
 | 
							resp.SUCCESS(c)
 | 
				
			||||||
	logger.Infof("Data: %+v", data)
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 推送消息到客户端
 | 
				
			||||||
 | 
						// TODO: 增加绘画消息类型
 | 
				
			||||||
 | 
						utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data})
 | 
				
			||||||
	resp.ERROR(c, "Error with CallBack")
 | 
						resp.ERROR(c, "Error with CallBack")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,13 +14,13 @@ const CodeStorePrefix = "/verify/codes/"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type SmsHandler struct {
 | 
					type SmsHandler struct {
 | 
				
			||||||
	BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	db      *store.LevelDB
 | 
						leveldb *store.LevelDB
 | 
				
			||||||
	sms     *service.AliYunSmsService
 | 
						sms     *service.AliYunSmsService
 | 
				
			||||||
	captcha *service.CaptchaService
 | 
						captcha *service.CaptchaService
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
 | 
					func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
 | 
				
			||||||
	handler := &SmsHandler{db: db, sms: sms, captcha: captcha}
 | 
						handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha}
 | 
				
			||||||
	handler.App = app
 | 
						handler.App = app
 | 
				
			||||||
	return handler
 | 
						return handler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -50,7 +50,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 存储验证码,等待后面注册验证
 | 
						// 存储验证码,等待后面注册验证
 | 
				
			||||||
	err = h.db.Put(CodeStorePrefix+data.Mobile, code)
 | 
						err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		resp.ERROR(c, "验证码保存失败")
 | 
							resp.ERROR(c, "验证码保存失败")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,11 +22,11 @@ type UserHandler struct {
 | 
				
			|||||||
	BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	db       *gorm.DB
 | 
						db       *gorm.DB
 | 
				
			||||||
	searcher *xdb.Searcher
 | 
						searcher *xdb.Searcher
 | 
				
			||||||
	levelDB  *store.LevelDB
 | 
						leveldb  *store.LevelDB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
 | 
					func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
 | 
				
			||||||
	handler := &UserHandler{db: db, searcher: searcher, levelDB: levelDB}
 | 
						handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB}
 | 
				
			||||||
	handler.App = app
 | 
						handler.App = app
 | 
				
			||||||
	return handler
 | 
						return handler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -60,7 +60,7 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
				
			|||||||
	key := CodeStorePrefix + data.Mobile
 | 
						key := CodeStorePrefix + data.Mobile
 | 
				
			||||||
	if h.App.SysConfig.EnabledMsgService {
 | 
						if h.App.SysConfig.EnabledMsgService {
 | 
				
			||||||
		var code int
 | 
							var code int
 | 
				
			||||||
		err := h.levelDB.Get(key, &code)
 | 
							err := h.leveldb.Get(key, &code)
 | 
				
			||||||
		if err != nil || code != data.Code {
 | 
							if err != nil || code != data.Code {
 | 
				
			||||||
			logger.Info(code)
 | 
								logger.Info(code)
 | 
				
			||||||
			resp.ERROR(c, "短信验证码错误")
 | 
								resp.ERROR(c, "短信验证码错误")
 | 
				
			||||||
@@ -118,7 +118,7 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.App.SysConfig.EnabledMsgService {
 | 
						if h.App.SysConfig.EnabledMsgService {
 | 
				
			||||||
		_ = h.levelDB.Delete(key) // 注册成功,删除短信验证码
 | 
							_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resp.SUCCESS(c, user)
 | 
						resp.SUCCESS(c, user)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -366,7 +366,7 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
 | 
				
			|||||||
	// 检查验证码
 | 
						// 检查验证码
 | 
				
			||||||
	key := CodeStorePrefix + data.Mobile
 | 
						key := CodeStorePrefix + data.Mobile
 | 
				
			||||||
	var code int
 | 
						var code int
 | 
				
			||||||
	err := h.levelDB.Get(key, &code)
 | 
						err := h.leveldb.Get(key, &code)
 | 
				
			||||||
	if err != nil || code != data.Code {
 | 
						if err != nil || code != data.Code {
 | 
				
			||||||
		resp.ERROR(c, "短信验证码错误")
 | 
							resp.ERROR(c, "短信验证码错误")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@@ -384,6 +384,6 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_ = h.levelDB.Delete(key) // 删除短信验证码
 | 
						_ = h.leveldb.Delete(key) // 删除短信验证码
 | 
				
			||||||
	resp.SUCCESS(c)
 | 
						resp.SUCCESS(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -36,12 +36,13 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
 | 
				
			|||||||
		delete(params, "ar")
 | 
							delete(params, "ar")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	prompt = prompt + " --niji 5"
 | 
						prompt = prompt + " --niji 5"
 | 
				
			||||||
 | 
						url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL)
 | 
				
			||||||
	var res types.BizVo
 | 
						var res types.BizVo
 | 
				
			||||||
	r, err := f.client.R().
 | 
						r, err := f.client.R().
 | 
				
			||||||
		SetHeader("Authorization", f.config.Token).
 | 
							SetHeader("Authorization", f.config.Token).
 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
		SetBody(params).
 | 
							SetBody(params).
 | 
				
			||||||
		SetSuccessResult(&res).Post(f.config.ApiURL)
 | 
							SetSuccessResult(&res).Post(url)
 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
		return "", fmt.Errorf("%v%v", r.String(), err)
 | 
							return "", fmt.Errorf("%v%v", r.String(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,20 +9,20 @@ import (
 | 
				
			|||||||
var logger = logger2.GetLogger()
 | 
					var logger = logger2.GetLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ReplyChunkMessage 回复客户片段端消息
 | 
					// ReplyChunkMessage 回复客户片段端消息
 | 
				
			||||||
func ReplyChunkMessage(client types.Client, message types.WsMessage) {
 | 
					func ReplyChunkMessage(client *types.WsClient, message types.WsMessage) {
 | 
				
			||||||
	msg, err := json.Marshal(message)
 | 
						msg, err := json.Marshal(message)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("Error for decoding json data: %v", err.Error())
 | 
							logger.Errorf("Error for decoding json data: %v", err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = client.(*types.WsClient).Send(msg)
 | 
						err = client.Send(msg)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("Error for reply message: %v", err.Error())
 | 
							logger.Errorf("Error for reply message: %v", err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ReplyMessage 回复客户端一条完整的消息
 | 
					// ReplyMessage 回复客户端一条完整的消息
 | 
				
			||||||
func ReplyMessage(ws types.Client, message string) {
 | 
					func ReplyMessage(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
						ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
 | 
						ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
 | 
						ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user