From b06754506a11137726f5cdb571b17b8275d9a8cc Mon Sep 17 00:00:00 2001 From: Rock Chin Date: Thu, 8 Dec 2022 13:22:54 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=AE=8C=E5=96=84session=E7=9A=84?= =?UTF-8?q?=E7=BB=B4=E6=8A=A4=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/database/manager.py | 38 ++++++++++++++++++++++++++++++-------- pkg/openai/session.py | 9 ++++++++- pkg/qqbot/manager.py | 2 +- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index be805f09..1aca8d2b 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,3 +1,4 @@ +import threading import time import pymysql @@ -26,9 +27,17 @@ class DatabaseManager: self.reconnect() + heartbeat_proxy = threading.Thread(target=self.heartbeat, daemon=True) + heartbeat_proxy.start() + global inst inst = self + def heartbeat(self): + while True: + self.conn.ping(reconnect=True) + time.sleep(30) + def reconnect(self): self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password, database=self.database, autocommit=True) @@ -43,6 +52,7 @@ class DatabaseManager: `number` bigint not null, `create_timestamp` bigint not null, `last_interact_timestamp` bigint not null, + `status` varchar(255) not null default 'on_going', `prompt` text not null ) """) @@ -70,11 +80,16 @@ class DatabaseManager: """.format(last_interact_timestamp, escape_string(prompt), subject_type, subject_number, create_timestamp)) + def explicit_close_session(self, session_name: str, create_timestamp: int): + self.cursor.execute(""" + update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} + """.format(session_name, create_timestamp)) + # 记载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session self.cursor.execute(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) results = self.cursor.fetchall() @@ -86,14 +101,21 @@ class DatabaseManager: create_timestamp = result[3] last_interact_timestamp = result[4] prompt = result[5] + status = result[6] + + # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 + if status == 'on_going': + sessions[session_name] = { + 'subject_type': subject_type, + 'subject_number': subject_number, + 'create_timestamp': create_timestamp, + 'last_interact_timestamp': last_interact_timestamp, + 'prompt': prompt + } + else: + if session_name in sessions: + del sessions[session_name] - sessions[session_name] = { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt - } return sessions diff --git a/pkg/openai/session.py b/pkg/openai/session.py index aa0f1005..b271f509 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -6,6 +6,11 @@ import pkg.database.manager sessions = {} +class SessionOfflineStatus: + ON_GOING = 'on_going' + EXPLICITLY_CLOSED = 'explicitly_closed' + + def load_sessions(): global sessions @@ -89,9 +94,11 @@ class Session: db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, self.prompt) - def reset(self): + def reset(self, explicit: bool = False): if self.prompt != '': self.persistence() + if explicit: + pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp) self.prompt = '' self.create_timestamp = int(time.time()) self.last_interact_timestamp = 0 diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index aa9a4d34..2172f4fe 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -63,7 +63,7 @@ class QQBotManager: if cmd == 'help': reply = "[bot]" + help_text elif cmd == 'reset': - pkg.openai.session.get_session(session_name).reset() + pkg.openai.session.get_session(session_name).reset(explicit=True) reply = "[bot]会话已重置" elif cmd == 'last': pass