From 6e40f92aafcf2677c6bd19760b27106ca24f1275 Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 14 Aug 2023 07:09:52 +0800 Subject: [PATCH] feat: add midjourney message receive handler --- api/core/app_server.go | 4 ++-- api/core/types/web.go | 5 +++-- api/handler/chat_handler.go | 20 ++++++++++++++------ api/handler/mj_handler.go | 29 +++++++++++++++++++++-------- api/handler/sms_handler.go | 6 +++--- api/handler/user_handler.go | 12 ++++++------ api/service/function/mid_journey.go | 3 ++- api/utils/websocket.go | 6 +++--- 8 files changed, 54 insertions(+), 31 deletions(-) diff --git a/api/core/app_server.go b/api/core/app_server.go index 4f0b0c55..85d1db31 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -34,7 +34,7 @@ type AppServer struct { ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function Functions map[string]function.Function - MjTasks *types.LMap[string, types.MjTask] + MjTaskClients *types.LMap[string, *types.WsClient] } func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer { @@ -48,7 +48,7 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio ChatSession: types.NewLMap[string, types.ChatSession](), ChatClients: types.NewLMap[string, *types.WsClient](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), - MjTasks: types.NewLMap[string, types.MjTask](), + MjTaskClients: types.NewLMap[string, *types.WsClient](), Functions: functions, } } diff --git a/api/core/types/web.go b/api/core/types/web.go index 48f0a26b..d100e592 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -12,8 +12,8 @@ type BizVo struct { // WsMessage Websocket message type WsMessage struct { - Type WsMsgType `json:"type"` // 消息类别,start, end - Content string `json:"content"` + Type WsMsgType `json:"type"` // 消息类别,start, end, img + Content interface{} `json:"content"` } type WsMsgType string @@ -21,6 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") + WsImg = WsMsgType("img") ) type BizCode int diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index b1e195f4..dc61173b 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -5,6 +5,7 @@ import ( "bytes" "chatplus/core" "chatplus/core/types" + "chatplus/store" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -26,14 +27,16 @@ import ( ) const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" +const TaskStorePrefix = "/tasks/" type ChatHandler struct { BaseHandler - db *gorm.DB + db *gorm.DB + leveldb *store.LevelDB } -func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { - handler := ChatHandler{db: db} +func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ChatHandler { + handler := ChatHandler{db: db, leveldb: levelDB} handler.App = app return &handler } @@ -133,7 +136,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { } // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 -func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error { +func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 var user model.User @@ -340,13 +343,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession if functionName == types.FuncMidJourney { key := utils.Sha256(data) // add task for MidJourney - h.App.MjTasks.Put(key, types.MjTask{ + h.App.MjTaskClients.Put(key, ws) + task := types.MjTask{ UserId: userVo.Id, RoleId: role.Id, Icon: role.Icon, Client: ws, ChatId: session.ChatId, - }) + } + err := h.leveldb.Put(TaskStorePrefix+key, task) + if err != nil { + logger.Error("error with store MidJourney task: ", err) + } content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 07782df8..365e5191 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -24,6 +24,7 @@ type Image struct { Width int `json:"width"` Height int `json:"height"` Size int `json:"size"` + Hash string `json:"hash"` } type MidJourneyHandler struct { @@ -44,18 +45,30 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } var data struct { - Image Image `json:"image"` - Content string `json:"content"` - Status TaskStatus `json:"status"` + Type string `json:"type"` + MessageId string `json:"message_id"` + Image Image `json:"image"` + Content string `json:"content"` + Prompt string `json:"prompt"` + Status TaskStatus `json:"status"` + Key string `json:"key"` } - if err := c.ShouldBindJSON(&data); err != nil { + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { resp.ERROR(c, types.InvalidArgs) return } + key := utils.Sha256(data.Prompt) + data.Key = key + // TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录 - sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o" - wsClient := h.App.ChatClients.Get(sessionId) - utils.ReplyMessage(wsClient, "![](https://cdn.discordapp.com/attachments/1138713254718361633/1139482452579070053/lal603743923_A_Chinese_girl_walking_barefoot_on_the_beach_weari_df8b6dc0-3b13-478c-8dbb-983015d21661.png)") - logger.Infof("Data: %+v", data) + wsClient := h.App.MjTaskClients.Get(key) + if wsClient == nil { // 客户端断线,则丢弃 + resp.SUCCESS(c) + return + } + + // 推送消息到客户端 + // TODO: 增加绘画消息类型 + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data}) resp.ERROR(c, "Error with CallBack") } diff --git a/api/handler/sms_handler.go b/api/handler/sms_handler.go index 2248800a..fba56a03 100644 --- a/api/handler/sms_handler.go +++ b/api/handler/sms_handler.go @@ -14,13 +14,13 @@ const CodeStorePrefix = "/verify/codes/" type SmsHandler struct { BaseHandler - db *store.LevelDB + leveldb *store.LevelDB sms *service.AliYunSmsService captcha *service.CaptchaService } func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler { - handler := &SmsHandler{db: db, sms: sms, captcha: captcha} + handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha} handler.App = app return handler } @@ -50,7 +50,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) { } // 存储验证码,等待后面注册验证 - err = h.db.Put(CodeStorePrefix+data.Mobile, code) + err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code) if err != nil { resp.ERROR(c, "验证码保存失败") return diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 81eae932..9c5225f8 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -22,11 +22,11 @@ type UserHandler struct { BaseHandler db *gorm.DB searcher *xdb.Searcher - levelDB *store.LevelDB + leveldb *store.LevelDB } func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler { - handler := &UserHandler{db: db, searcher: searcher, levelDB: levelDB} + handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB} handler.App = app return handler } @@ -60,7 +60,7 @@ func (h *UserHandler) Register(c *gin.Context) { key := CodeStorePrefix + data.Mobile if h.App.SysConfig.EnabledMsgService { var code int - err := h.levelDB.Get(key, &code) + err := h.leveldb.Get(key, &code) if err != nil || code != data.Code { logger.Info(code) resp.ERROR(c, "短信验证码错误") @@ -118,7 +118,7 @@ func (h *UserHandler) Register(c *gin.Context) { } if h.App.SysConfig.EnabledMsgService { - _ = h.levelDB.Delete(key) // 注册成功,删除短信验证码 + _ = h.leveldb.Delete(key) // 注册成功,删除短信验证码 } resp.SUCCESS(c, user) } @@ -366,7 +366,7 @@ func (h *UserHandler) BindMobile(c *gin.Context) { // 检查验证码 key := CodeStorePrefix + data.Mobile var code int - err := h.levelDB.Get(key, &code) + err := h.leveldb.Get(key, &code) if err != nil || code != data.Code { resp.ERROR(c, "短信验证码错误") return @@ -384,6 +384,6 @@ func (h *UserHandler) BindMobile(c *gin.Context) { return } - _ = h.levelDB.Delete(key) // 删除短信验证码 + _ = h.leveldb.Delete(key) // 删除短信验证码 resp.SUCCESS(c) } diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index 4996466e..2c510e3f 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -36,12 +36,13 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { delete(params, "ar") } prompt = prompt + " --niji 5" + url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL) var res types.BizVo r, err := f.client.R(). SetHeader("Authorization", f.config.Token). SetHeader("Content-Type", "application/json"). SetBody(params). - SetSuccessResult(&res).Post(f.config.ApiURL) + SetSuccessResult(&res).Post(url) if err != nil || r.IsErrorState() { return "", fmt.Errorf("%v%v", r.String(), err) } diff --git a/api/utils/websocket.go b/api/utils/websocket.go index e161d97b..a1d0399d 100644 --- a/api/utils/websocket.go +++ b/api/utils/websocket.go @@ -9,20 +9,20 @@ import ( var logger = logger2.GetLogger() // ReplyChunkMessage 回复客户片段端消息 -func ReplyChunkMessage(client types.Client, message types.WsMessage) { +func ReplyChunkMessage(client *types.WsClient, message types.WsMessage) { msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) return } - err = client.(*types.WsClient).Send(msg) + err = client.Send(msg) if err != nil { logger.Errorf("Error for reply message: %v", err.Error()) } } // ReplyMessage 回复客户端一条完整的消息 -func ReplyMessage(ws types.Client, message string) { +func ReplyMessage(ws *types.WsClient, message interface{}) { ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message}) ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})