diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 999d7315..d76dc0cf 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -54,20 +54,27 @@ class DatabaseManager: `last_interact_timestamp` bigint not null, `status` varchar(255) not null default 'on_going', `default_prompt` text not null default '', - `prompt` text not null + `prompt` text not null, + `token_counts` text not null default '[]', ) """) - # 检查sessions表是否存在`default_prompt`字段 + # 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段 self.__execute__("PRAGMA table_info('sessions')") columns = self.cursor.fetchall() has_default_prompt = False + has_token_counts = False for field in columns: if field[1] == 'default_prompt': has_default_prompt = True + if field[1] == 'token_counts': + has_token_counts = True + if has_default_prompt and has_token_counts: break if not has_default_prompt: self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") + if not has_token_counts: + self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'") self.__execute__(""" @@ -89,7 +96,7 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str, default_prompt: str = ''): + last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []): """持久化指定session""" # 检查是否已经有了此name和create_timestamp的session @@ -102,20 +109,20 @@ class DatabaseManager: if count == 0: sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`) - values (?, ?, ?, ?, ?, ?, ?) + insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`) + values (?, ?, ?, ?, ?, ?, ?, ?) """ self.__execute__(sql, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt, default_prompt)) + last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts))) else: sql = """ - update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? + update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ? where `type` = ? and `number` = ? and `create_timestamp` = ? """ - self.__execute__(sql, (last_interact_timestamp, prompt, subject_type, + self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type, subject_number, create_timestamp)) # 显式关闭一个session @@ -140,7 +147,7 @@ class DatabaseManager: # 从数据库中加载所有还没过期的session config = pkg.utils.context.get_config() self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) results = self.cursor.fetchall() @@ -154,6 +161,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 if status == 'on_going': @@ -163,7 +171,8 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } else: if session_name in sessions: @@ -175,7 +184,7 @@ class DatabaseManager: def last_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 """.format(session_name, cursor_timestamp)) @@ -192,6 +201,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] return { 'subject_type': subject_type, @@ -199,14 +209,15 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 """.format(session_name, cursor_timestamp)) @@ -223,6 +234,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] return { 'subject_type': subject_type, @@ -230,13 +242,14 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page)) results = self.cursor.fetchall() @@ -250,6 +263,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] sessions.append({ 'subject_type': subject_type, @@ -257,7 +271,8 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) }) return sessions