feat: 使用SQLite作为数据库

This commit is contained in:
Rock Chin
2022-12-11 13:58:47 +08:00
parent 1695604226
commit 9c7271c1f8
3 changed files with 14 additions and 40 deletions

View File

@@ -15,20 +15,6 @@ mirai_http_api_config = {
"qq": 0 "qq": 0
} }
# [必需] MySQL数据库的配置
# host: 数据库地址
# port: 数据库端口
# user: 数据库用户名
# password: 数据库密码
# database: 数据库名
mysql_config = {
"host": "",
"port": 3306,
"user": "",
"password": "",
"database": ""
}
# [必需] OpenAI的配置 # [必需] OpenAI的配置
# api_key: OpenAI的API Key # api_key: OpenAI的API Key
openai_config = { openai_config = {

View File

@@ -17,9 +17,8 @@ log_colors_config = {
def init_db(): def init_db():
import config
import pkg.database.manager import pkg.database.manager
database = pkg.database.manager.DatabaseManager(**config.mysql_config) database = pkg.database.manager.DatabaseManager()
database.initialize_database() database.initialize_database()
@@ -54,7 +53,7 @@ def main():
# 主启动流程 # 主启动流程
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params)
database = pkg.database.manager.DatabaseManager(**config.mysql_config) database = pkg.database.manager.DatabaseManager()
# 加载所有未超时的session # 加载所有未超时的session
pkg.openai.session.load_sessions() pkg.openai.session.load_sessions()

View File

@@ -1,52 +1,35 @@
import threading
import time import time
import pymysql
from pymysql.converters import escape_string from pymysql.converters import escape_string
import sqlite3
import config import config
inst = None inst = None
class DatabaseManager: class DatabaseManager:
host = ''
port = 0
user = ''
password = ''
database = ''
conn = None conn = None
cursor = None cursor = None
def __init__(self, host: str, port: int, user: str, password: str, database: str): def __init__(self):
self.host = host
self.port = port
self.user = user
self.password = password
self.database = database
self.reconnect() self.reconnect()
heartbeat_proxy = threading.Thread(target=self.heartbeat, daemon=True)
heartbeat_proxy.start()
global inst global inst
inst = self inst = self
def heartbeat(self):
while True:
time.sleep(30)
self.conn.ping(reconnect=True)
def reconnect(self): def reconnect(self):
self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password, self.conn = sqlite3.connect('database.db', check_same_thread=False)
database=self.database, autocommit=True) # self.conn.isolation_level = None
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
def initialize_database(self): def initialize_database(self):
self.cursor.execute(""" self.cursor.execute("""
create table if not exists `sessions` ( create table if not exists `sessions` (
`id` bigint not null auto_increment primary key, `id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` varchar(255) not null, `name` varchar(255) not null,
`type` varchar(255) not null, `type` varchar(255) not null,
`number` bigint not null, `number` bigint not null,
@@ -56,6 +39,7 @@ class DatabaseManager:
`prompt` text not null `prompt` text not null
) )
""") """)
self.conn.commit()
print('Database initialized.') print('Database initialized.')
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
@@ -73,27 +57,32 @@ class DatabaseManager:
values ('{}', '{}', {}, {}, {}, '{}') values ('{}', '{}', {}, {}, {}, '{}')
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, escape_string(prompt))) last_interact_timestamp, escape_string(prompt)))
self.conn.commit()
else: else:
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
where `type` = '{}' and `number` = {} and `create_timestamp` = {} where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(last_interact_timestamp, escape_string(prompt), subject_type, """.format(last_interact_timestamp, escape_string(prompt), subject_type,
subject_number, create_timestamp)) subject_number, create_timestamp))
self.conn.commit()
def explicit_close_session(self, session_name: str, create_timestamp: int): def explicit_close_session(self, session_name: str, create_timestamp: int):
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
self.conn.commit()
def set_session_ongoing(self, session_name: str, create_timestamp: int): def set_session_ongoing(self, session_name: str, create_timestamp: int):
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
self.conn.commit()
def set_session_expired(self, session_name: str, create_timestamp: int): def set_session_expired(self, session_name: str, create_timestamp: int):
self.cursor.execute(""" self.cursor.execute("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
self.conn.commit()
# 记载还没过期的session数据 # 记载还没过期的session数据
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict: