feat: 分离储存会话情景预设和对话内容

This commit is contained in:
Rock Chin
2023-03-11 23:44:22 +08:00
parent e9155e836f
commit 9cd7e49804
2 changed files with 59 additions and 28 deletions

View File

@@ -52,10 +52,23 @@ class DatabaseManager:
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null
)
""")
# 检查sessions表是否存在`default_prompt`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
self.__execute__("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -75,7 +88,7 @@ class DatabaseManager:
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str):
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
@@ -88,13 +101,13 @@ class DatabaseManager:
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
values (?, ?, ?, ?, ?, ?)
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
values (?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt))
last_interact_timestamp, prompt, default_prompt))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
@@ -126,7 +139,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
@@ -139,6 +152,7 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
@@ -147,7 +161,8 @@ class DatabaseManager:
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt
}
else:
if session_name in sessions:
@@ -159,7 +174,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
@@ -175,20 +190,22 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
@@ -204,19 +221,21 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
@@ -229,13 +248,15 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
sessions.append({
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt
})
return sessions