From 8c4836dde4784426a2af7a6e88816135010b8d0c Mon Sep 17 00:00:00 2001 From: Rock Chin Date: Mon, 12 Dec 2022 23:07:15 +0800 Subject: [PATCH] =?UTF-8?q?debug:=20=E8=BE=93=E5=87=BA=E6=AF=8F=E4=B8=AASQ?= =?UTF-8?q?L?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/database/manager.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 5b806a8d..fa0ea46f 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,4 +1,6 @@ +import logging import time +from sqlite3 import Cursor from pymysql.converters import escape_string @@ -27,9 +29,15 @@ class DatabaseManager: # self.conn.isolation_level = None self.cursor = self.conn.cursor() + def execute(self, sql: str) -> Cursor: + c = self.cursor.execute(sql) + logging.debug('SQL: {}'.format(sql)) + self.conn.commit() + return c + # 初始化数据库的函数 def initialize_database(self): - self.cursor.execute(""" + self.execute(""" create table if not exists `sessions` ( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `name` varchar(255) not null, @@ -41,7 +49,6 @@ class DatabaseManager: `prompt` text not null ) """) - self.conn.commit() print('Database initialized.') # session持久化 @@ -50,49 +57,44 @@ class DatabaseManager: # 检查是否已经有了此name和create_timestamp的session # 如果有,就更新prompt和last_interact_timestamp # 如果没有,就插入一条新的记录 - self.cursor.execute(""" + self.execute(""" select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} """.format(subject_type, subject_number, create_timestamp)) count = self.cursor.fetchone()[0] if count == 0: - self.cursor.execute(""" + self.execute(""" insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) values ('{}', '{}', {}, {}, {}, '{}') """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, last_interact_timestamp, escape_string(prompt))) - self.conn.commit() else: - self.cursor.execute(""" + self.execute(""" update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' where `type` = '{}' and `number` = {} and `create_timestamp` = {} """.format(last_interact_timestamp, escape_string(prompt), subject_type, subject_number, create_timestamp)) - self.conn.commit() # 显式关闭一个session def explicit_close_session(self, session_name: str, create_timestamp: int): - self.cursor.execute(""" + self.execute(""" update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) - self.conn.commit() def set_session_ongoing(self, session_name: str, create_timestamp: int): - self.cursor.execute(""" + self.execute(""" update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) - self.conn.commit() # 设置session为过期 def set_session_expired(self, session_name: str, create_timestamp: int): - self.cursor.execute(""" + self.execute(""" update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) - self.conn.commit() # 从数据库加载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session - self.cursor.execute(""" + self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) @@ -125,7 +127,7 @@ class DatabaseManager: # 获取此session_name前一个session的数据 def last_session(self, session_name: str, cursor_timestamp: int): - self.cursor.execute(""" + self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 @@ -154,7 +156,7 @@ class DatabaseManager: # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): - self.cursor.execute(""" + self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 @@ -182,7 +184,7 @@ class DatabaseManager: # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): - self.cursor.execute(""" + self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page))