mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: 数据库接口支持
This commit is contained in:
@@ -54,20 +54,27 @@ class DatabaseManager:
|
||||
`last_interact_timestamp` bigint not null,
|
||||
`status` varchar(255) not null default 'on_going',
|
||||
`default_prompt` text not null default '',
|
||||
`prompt` text not null
|
||||
`prompt` text not null,
|
||||
`token_counts` text not null default '[]',
|
||||
)
|
||||
""")
|
||||
|
||||
# 检查sessions表是否存在`default_prompt`字段
|
||||
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
|
||||
self.__execute__("PRAGMA table_info('sessions')")
|
||||
columns = self.cursor.fetchall()
|
||||
has_default_prompt = False
|
||||
has_token_counts = False
|
||||
for field in columns:
|
||||
if field[1] == 'default_prompt':
|
||||
has_default_prompt = True
|
||||
if field[1] == 'token_counts':
|
||||
has_token_counts = True
|
||||
if has_default_prompt and has_token_counts:
|
||||
break
|
||||
if not has_default_prompt:
|
||||
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
|
||||
if not has_token_counts:
|
||||
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
|
||||
|
||||
|
||||
self.__execute__("""
|
||||
@@ -89,7 +96,7 @@ class DatabaseManager:
|
||||
|
||||
# session持久化
|
||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
|
||||
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []):
|
||||
"""持久化指定session"""
|
||||
|
||||
# 检查是否已经有了此name和create_timestamp的session
|
||||
@@ -102,20 +109,20 @@ class DatabaseManager:
|
||||
if count == 0:
|
||||
|
||||
sql = """
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
|
||||
values (?, ?, ?, ?, ?, ?, ?)
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
self.__execute__(sql,
|
||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt, default_prompt))
|
||||
last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts)))
|
||||
else:
|
||||
sql = """
|
||||
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
|
||||
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
|
||||
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
||||
"""
|
||||
|
||||
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
|
||||
self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type,
|
||||
subject_number, create_timestamp))
|
||||
|
||||
# 显式关闭一个session
|
||||
@@ -140,7 +147,7 @@ class DatabaseManager:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
config = pkg.utils.context.get_config()
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
""".format(int(time.time()) - config.session_expire_time))
|
||||
results = self.cursor.fetchall()
|
||||
@@ -154,6 +161,7 @@ class DatabaseManager:
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
token_counts = result[8]
|
||||
|
||||
# 当且仅当最后一个该对象的会话是on_going状态时,才会被加载
|
||||
if status == 'on_going':
|
||||
@@ -163,7 +171,8 @@ class DatabaseManager:
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
}
|
||||
else:
|
||||
if session_name in sessions:
|
||||
@@ -175,7 +184,7 @@ class DatabaseManager:
|
||||
def last_session(self, session_name: str, cursor_timestamp: int):
|
||||
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
||||
limit 1
|
||||
""".format(session_name, cursor_timestamp))
|
||||
@@ -192,6 +201,7 @@ class DatabaseManager:
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
token_counts = result[8]
|
||||
|
||||
return {
|
||||
'subject_type': subject_type,
|
||||
@@ -199,14 +209,15 @@ class DatabaseManager:
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
}
|
||||
|
||||
# 获取此session_name后一个session的数据
|
||||
def next_session(self, session_name: str, cursor_timestamp: int):
|
||||
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
||||
limit 1
|
||||
""".format(session_name, cursor_timestamp))
|
||||
@@ -223,6 +234,7 @@ class DatabaseManager:
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
token_counts = result[8]
|
||||
|
||||
return {
|
||||
'subject_type': subject_type,
|
||||
@@ -230,13 +242,14 @@ class DatabaseManager:
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
}
|
||||
|
||||
# 列出与某个对象的所有对话session
|
||||
def list_history(self, session_name: str, capacity: int, page: int):
|
||||
self.__execute__("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
||||
""".format(session_name, capacity, capacity * page))
|
||||
results = self.cursor.fetchall()
|
||||
@@ -250,6 +263,7 @@ class DatabaseManager:
|
||||
prompt = result[5]
|
||||
status = result[6]
|
||||
default_prompt = result[7]
|
||||
token_counts = result[8]
|
||||
|
||||
sessions.append({
|
||||
'subject_type': subject_type,
|
||||
@@ -257,7 +271,8 @@ class DatabaseManager:
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt,
|
||||
'default_prompt': default_prompt
|
||||
'default_prompt': default_prompt,
|
||||
'token_counts': json.loads(token_counts)
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
Reference in New Issue
Block a user