From cb2b01127b86d71691c14868c5f81100de464e35 Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 28 Mar 2023 18:13:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9A=E6=9C=9F=E6=B8=85=E7=90=86=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=81=8A=E5=A4=A9=E4=BC=9A=E8=AF=9D=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87=20ChatContext?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat_handler.go | 12 ++++++++---- server/config_handler.go | 14 +++++++++++++- server/server.go | 23 +++++++++++++++++++---- types/chat.go | 6 ++++++ types/config.go | 13 +++++++------ utils/config.go | 13 +++++++------ 6 files changed, 60 insertions(+), 21 deletions(-) diff --git a/server/chat_handler.go b/server/chat_handler.go index 617a4885..4a43eba4 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -80,9 +80,9 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro Stream: true, } var context []types.Message - var key = session.SessionId + role.Name - if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext { - context = v + var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key) + if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext { + context = v.Messages } else { context = role.Context } @@ -206,7 +206,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro context = append(context, useMsg) message.Content = strings.Join(contents, "") context = append(context, message) - s.ChatContext[key] = context + // 更新上下文消息 + s.ChatContexts[ctxKey] = types.ChatContext{ + Messages: context, + LastAccessTime: time.Now().Unix(), + } // 追加历史消息 if user.EnableHistory { diff --git a/server/config_handler.go b/server/config_handler.go index 40512288..11c62b45 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -63,6 +63,18 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { s.Config.Chat.EnableContext = v } + if expireTime, ok := data["chat_context_expire_time"]; ok { + v, err := strconv.Atoi(expireTime) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{ + Code: types.InvalidParams, + Message: "chat_context_expire_time must be a integer parameter", + }) + return + } + s.Config.Chat.ChatContextExpireTime = v + } + // enable auth if enableAuth, ok := data["enable_auth"]; ok { v, err := strconv.ParseBool(enableAuth) @@ -123,7 +135,7 @@ func (s *Server) AddUserHandle(c *gin.Context) { return } - user := types.User{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls} + user := types.User{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls, EnableHistory: data.EnableHistory} err = PutUser(user) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) diff --git a/server/server.go b/server/server.go index 125e6ab0..c534c3c4 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ import ( "path/filepath" "runtime/debug" "strings" + "time" ) var logger = logger2.GetLogger() @@ -31,9 +32,9 @@ func (s StaticFile) Open(name string) (fs.File, error) { } type Server struct { - Config *types.Config - ConfigPath string - ChatContext map[string][]types.Message // 聊天上下文 [SessionID] => []Messages + Config *types.Config + ConfigPath string + ChatContexts map[string]types.ChatContext // 聊天上下文 [SessionID+ChatRole] => ChatContext // 保存 Websocket 会话 Username, 每个 Username 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API @@ -61,7 +62,7 @@ func NewServer(configPath string) (*Server, error) { return &Server{ Config: config, ConfigPath: configPath, - ChatContext: make(map[string][]types.Message, 16), + ChatContexts: make(map[string]types.ChatContext, 16), ChatSession: make(map[string]types.ChatSession), ApiKeyAccessStat: make(map[string]int64), }, nil @@ -111,6 +112,20 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { path: path, })) + // 定时清理过期的会话 + go func() { + for { + for key, context := range s.ChatContexts { + // 清理超过 60min 没有更新,则表示为过期会话 + if time.Now().Unix()-context.LastAccessTime > 3600 { + logger.Infof("清理会话上下文: %s", key) + delete(s.ChatContexts, key) + } + } + time.Sleep(time.Second * 5) // 每隔 5 秒钟清理一次 + } + }() + logger.Infof("http://%s", s.Config.Listen) err := engine.Run(s.Config.Listen) diff --git a/types/chat.go b/types/chat.go index ceed1a69..c79168c8 100644 --- a/types/chat.go +++ b/types/chat.go @@ -48,6 +48,12 @@ type ChatSession struct { Username string `json:"user"` // 当前登录的 user } +// ChatContext 聊天上下文 +type ChatContext struct { + Messages []Message + LastAccessTime int64 // 最后一次访问上下文时间 +} + func GetDefaultChatRole() map[string]ChatRole { return map[string]ChatRole{ "gpt": { diff --git a/types/config.go b/types/config.go index 5b620815..5e429538 100644 --- a/types/config.go +++ b/types/config.go @@ -22,12 +22,13 @@ type User struct { // Chat configs struct type Chat struct { - ApiURL string - ApiKeys []string - Model string - Temperature float32 - MaxTokens int - EnableContext bool // 是否保持聊天上下文 + ApiURL string + ApiKeys []string + Model string + Temperature float32 + MaxTokens int + EnableContext bool // 是否保持聊天上下文 + ChatContextExpireTime int // 聊天上下文过期时间,单位:秒 } // Session configs struct diff --git a/utils/config.go b/utils/config.go index 22048329..9a00401a 100644 --- a/utils/config.go +++ b/utils/config.go @@ -27,12 +27,13 @@ func NewDefaultConfig() *types.Config { SameSite: http.SameSiteLaxMode, }, Chat: types.Chat{ - ApiURL: "https://api.openai.com/v1/chat/completions", - ApiKeys: []string{""}, - Model: "gpt-3.5-turbo", - MaxTokens: 1024, - Temperature: 0.9, - EnableContext: true, + ApiURL: "https://api.openai.com/v1/chat/completions", + ApiKeys: []string{""}, + Model: "gpt-3.5-turbo", + MaxTokens: 1024, + Temperature: 0.9, + EnableContext: true, + ChatContextExpireTime: 3600, }, } }