From 6ef09c8ad54feb9946e4c71495759bd01477548c Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 25 Sep 2024 18:43:12 +0800 Subject: [PATCH] add ws handler --- api/core/types/web.go | 26 +++++--- api/handler/chatimpl/chat_handler.go | 4 +- api/handler/chatimpl/openai_handler.go | 18 +++--- api/handler/markmap_handler.go | 8 +-- api/handler/ws_handler.go | 85 ++++++++++++++++++++++++++ api/service/ws_service.go | 7 +++ api/utils/net.go | 16 ++--- web/src/views/mobile/Index.vue | 2 +- web/src/views/mobile/Profile.vue | 3 +- 9 files changed, 136 insertions(+), 33 deletions(-) create mode 100644 api/handler/ws_handler.go create mode 100644 api/service/ws_service.go diff --git a/api/core/types/web.go b/api/core/types/web.go index eb9683bf..514b9b95 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -19,23 +19,35 @@ type BizVo struct { // ReplyMessage 对话回复消息结构 type ReplyMessage struct { - Type WsMsgType `json:"type"` // 消息类别,start, end, img + Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat + Type WsMsgType `json:"type"` // 消息类别 Content interface{} `json:"content"` } type WsMsgType string +type WsChannel string const ( - WsContent = WsMsgType("content") // 输出内容 - WsEnd = WsMsgType("end") - WsErr = WsMsgType("error") + WsMsgTypeContent = WsMsgType("content") // 输出内容 + WsMsgTypeEnd = WsMsgType("end") + WsMsgTypeErr = WsMsgType("error") + WsMsgTypePing = WsMsgType("ping") // 心跳消息 + + WsChat = WsChannel("chat") + WsMj = WsChannel("mj") + WsSd = WsChannel("sd") + WsDall = WsChannel("dall") + WsSuno = WsChannel("suno") + WsLuma = WsChannel("luma") ) // InputMessage 对话输入消息结构 type InputMessage struct { - Content string `json:"content"` - Tools []int `json:"tools"` // 允许调用工具列表 - Stream bool `json:"stream"` // 是否采用流式输出 + Channel WsChannel `json:"channel"` // 消息频道 + Type WsMsgType `json:"type"` // 消息类别 + Content string `json:"content"` + Tools []int `json:"tools"` // 允许调用工具列表 + Stream bool `json:"stream"` // 是否采用流式输出 } type BizCode int diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 8ea002fa..863ab15f 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -151,9 +151,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) if err != nil { logger.Error(err) - utils.ReplyMessage(client, err.Error()) + utils.SendMessage(client, err.Error()) } else { - utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) + utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) logger.Infof("回答完毕: %v", message.Content) } diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index c0e03a06..a66dcd30 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage( } if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { - utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") + utils.SendMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") break } @@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage( if res.Error == nil { toolCall = true callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) - utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg}) + utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg}) contents = append(contents, callMsg) } continue @@ -153,8 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage( } else { content := responseBody.Choices[0].Delta.Content contents = append(contents, utils.InterfaceToString(content)) - utils.ReplyChunkMessage(ws, types.ReplyMessage{ - Type: types.WsContent, + utils.SendChunkMessage(ws, types.ReplyMessage{ + Type: types.WsMsgTypeContent, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) } @@ -186,14 +186,14 @@ func (h *ChatHandler) sendOpenAiMessage( } if errMsg != "" || apiRes.Code != types.Success { msg := "调用函数工具出错:" + apiRes.Message + errMsg - utils.ReplyChunkMessage(ws, types.ReplyMessage{ - Type: types.WsContent, + utils.SendChunkMessage(ws, types.ReplyMessage{ + Type: types.WsMsgTypeContent, Content: msg, }) contents = append(contents, msg) } else { - utils.ReplyChunkMessage(ws, types.ReplyMessage{ - Type: types.WsContent, + utils.SendChunkMessage(ws, types.ReplyMessage{ + Type: types.WsMsgTypeContent, Content: apiRes.Data, }) contents = append(contents, utils.InterfaceToString(apiRes.Data)) @@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage( if strings.HasPrefix(req.Model, "o1-") { content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) } - utils.ReplyMessage(ws, content) + utils.SendMessage(ws, content) respVo.Usage.Prompt = prompt respVo.Usage.Content = content h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now()) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index afdda797..0cfa1743 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -87,7 +87,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) { logger.Error(err) utils.ReplyErrorMessage(client, err.Error()) } else { - utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd}) + utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) } } @@ -170,13 +170,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode break } - utils.ReplyChunkMessage(client, types.ReplyMessage{ - Type: types.WsContent, + utils.SendChunkMessage(client, types.ReplyMessage{ + Type: types.WsMsgTypeContent, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) } // end for - utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) + utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) } else { body, _ := io.ReadAll(response.Body) diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go new file mode 100644 index 00000000..8227c98e --- /dev/null +++ b/api/handler/ws_handler.go @@ -0,0 +1,85 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/store/model" + "geekai/utils" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "net/http" +) + +// Websocket 连接处理 handler + +type WebsocketHandler struct { + BaseHandler + wsService *service.WebsocketService +} + +func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler { + return &WebsocketHandler{ + BaseHandler: BaseHandler{App: app}, + wsService: s, + } +} + +func (h *WebsocketHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + clientId := c.Query("client") + client := types.NewWsClient(ws) + if userId == 0 { + _ = client.Send([]byte("Invalid user_id")) + c.Abort() + return + } + var user model.User + if err := h.DB.Where("id", userId).First(&user).Error; err != nil { + _ = client.Send([]byte("Invalid user_id")) + c.Abort() + return + } + + h.wsService.Clients.Put(clientId, client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) + client.Close() + } + + var message types.InputMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + logger.Infof("Receive a message:%+v", message) + if message.Type == types.WsMsgTypePing { + _ = client.Send([]byte(`{"type":"pong"}`)) + } + + switch message.Channel { + + } + + } + }() +} diff --git a/api/service/ws_service.go b/api/service/ws_service.go new file mode 100644 index 00000000..65def76f --- /dev/null +++ b/api/service/ws_service.go @@ -0,0 +1,7 @@ +package service + +import "geekai/core/types" + +type WebsocketService struct { + Clients *types.LMap[string, *types.WsClient] // clientId => Client +} diff --git a/api/utils/net.go b/api/utils/net.go index d88a36c1..f103a870 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -19,8 +19,8 @@ import ( var logger = logger2.GetLogger() -// ReplyChunkMessage 回复客户片段端消息 -func ReplyChunkMessage(client *types.WsClient, message interface{}) { +// SendChunkMessage 回复客户片段端消息 +func SendChunkMessage(client *types.WsClient, message interface{}) { msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) @@ -32,19 +32,19 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) { } } -// ReplyMessage 回复客户端一条完整的消息 -func ReplyMessage(ws *types.WsClient, message interface{}) { - ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message}) - ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd}) +// SendMessage 回复客户端一条完整的消息 +func SendMessage(ws *types.WsClient, message interface{}) { + SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message}) + SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd}) } func ReplyContent(ws *types.WsClient, message interface{}) { - ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message}) + SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message}) } // ReplyErrorMessage 向客户端发送错误消息 func ReplyErrorMessage(ws *types.WsClient, message interface{}) { - ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message}) + SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message}) } func DownloadImage(imageURL string, proxy string) ([]byte, error) { diff --git a/web/src/views/mobile/Index.vue b/web/src/views/mobile/Index.vue index dd5430dd..1a469202 100644 --- a/web/src/views/mobile/Index.vue +++ b/web/src/views/mobile/Index.vue @@ -109,7 +109,7 @@ onMounted(() => { }) const fetchApps = () => { - httpGet("/api/app/list/user").then((res) => { + httpGet("/api/app/list").then((res) => { const items = res.data // 处理 hello message for (let i = 0; i < items.length; i++) { diff --git a/web/src/views/mobile/Profile.vue b/web/src/views/mobile/Profile.vue index d6a9cf0c..630516b5 100644 --- a/web/src/views/mobile/Profile.vue +++ b/web/src/views/mobile/Profile.vue @@ -153,8 +153,7 @@ import {onMounted, ref} from "vue"; import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant"; import {httpGet, httpPost} from "@/utils/http"; -import Compressor from 'compressorjs'; -import {dateFormat, isWeChatBrowser, showLoginDialog} from "@/utils/libs"; +import {dateFormat, showLoginDialog} from "@/utils/libs"; import {ElMessage} from "element-plus"; import {checkSession, getSystemInfo} from "@/store/cache"; import {useRouter} from "vue-router";