feat: 数据库接口支持

This commit is contained in:
Rock Chin
2023-03-18 12:57:36 +00:00
parent 0490ad9207
commit d056cb6769

View File

@@ -54,20 +54,27 @@ class DatabaseManager:
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null
`prompt` text not null,
`token_counts` text not null default '[]',
)
""")
# 检查sessions表是否存在`default_prompt`字段
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
has_token_counts = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__("""
@@ -89,7 +96,7 @@ class DatabaseManager:
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
@@ -102,20 +109,20 @@ class DatabaseManager:
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
values (?, ?, ?, ?, ?, ?, ?)
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt, default_prompt))
last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts)))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type,
subject_number, create_timestamp))
# 显式关闭一个session
@@ -140,7 +147,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
@@ -154,6 +161,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
@@ -163,7 +171,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': json.loads(token_counts)
}
else:
if session_name in sessions:
@@ -175,7 +184,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
@@ -192,6 +201,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
@@ -199,14 +209,15 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': json.loads(token_counts)
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
@@ -223,6 +234,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
@@ -230,13 +242,14 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': json.loads(token_counts)
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
@@ -250,6 +263,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
sessions.append({
'subject_type': subject_type,
@@ -257,7 +271,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': json.loads(token_counts)
})
return sessions