diff --git a/server/chat_handler.go b/server/chat_handler.go index 47547551..19929d41 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -42,6 +42,9 @@ func (s *Server) ChatHandle(c *gin.Context) { c.Abort() return } + // 保存会话连接 + s.ChatClients[sessionId] = client + // 加载历史消息,如果历史消息为空则发送打招呼消息 _, err = GetChatHistory(session.Username, roleKey) if err != nil { @@ -53,6 +56,7 @@ func (s *Server) ChatHandle(c *gin.Context) { if err != nil { logger.Error(err) client.Close() + delete(s.ChatClients, sessionId) return } diff --git a/server/server.go b/server/server.go index 716d1230..b0809c2e 100644 --- a/server/server.go +++ b/server/server.go @@ -39,9 +39,10 @@ type Server struct { // 保存 Websocket 会话 Username, 每个 Username 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API - ChatSession map[string]types.ChatSession - ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 - DebugMode bool // 是否开启调试模式 + ChatSession map[string]types.ChatSession //map[sessionId]User + ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 + ChatClients map[string]*WsClient // Websocket 连接集合 + DebugMode bool // 是否开启调试模式 } func NewServer(configPath string) (*Server, error) { @@ -65,6 +66,7 @@ func NewServer(configPath string) (*Server, error) { ConfigPath: configPath, ChatContexts: make(map[string]types.ChatContext, 16), ChatSession: make(map[string]types.ChatSession), + ChatClients: make(map[string]*WsClient), ApiKeyAccessStat: make(map[string]int64), }, nil } @@ -84,6 +86,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.POST("/test", s.TestHandle) engine.GET("/api/session/get", s.GetSessionHandle) engine.POST("/api/login", s.LoginHandle) + engine.POST("/api/logout", s.LogoutHandle) engine.Any("/api/chat", s.ChatHandle) engine.POST("api/chat/history", s.GetChatHistoryHandle) engine.POST("api/chat/history/clear", s.ClearHistoryHandle) @@ -307,3 +310,34 @@ func (s *Server) LoginHandle(c *gin.Context) { SessionId string `json:"session_id"` }{User: *user, SessionId: sessionId}}) } + +// LogoutHandle 注销 +func (s *Server) LogoutHandle(c *gin.Context) { + var data struct { + Opt string `json:"opt"` + } + err := json.NewDecoder(c.Request.Body).Decode(&data) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) + return + } + + if data.Opt == "logout" { + sessionId := c.GetHeader(types.TokenName) + session := sessions.Default(c) + session.Delete(sessionId) + err := session.Save() + if err != nil { + logger.Error("Error for save session: ", err) + } + // 删除 websocket 会话列表 + delete(s.ChatSession, sessionId) + // 关闭 socket 连接 + if client, ok := s.ChatClients[sessionId]; ok { + client.Close() + } + c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) + } else { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Hack attempt!"}) + } +} diff --git a/web/public/images/user-info.jpg b/web/public/images/user-info.jpg new file mode 100644 index 00000000..580598a1 Binary files /dev/null and b/web/public/images/user-info.jpg differ diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 5c63e902..55feb2b6 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -14,7 +14,7 @@ - +