diff --git a/README.md b/README.md index 7451c22c..4b2a63f9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # QChatGPT 通过调用OpenAI GPT-3模型提供的Completion API来实现一个更加智能的QQ机器人 +使用SQLite储存会话内容,保证回复内容符合上下文 已部署的测试机器人QQ: 960164003 交流群、答疑群: 204785790 无需云服务器,在个人电脑上即可部署 @@ -9,7 +10,6 @@ - [Mirai](https://github.com/mamoe/mirai) 高效率 QQ 机器人支持库 - [YiriMirai](https://github.com/YiriMiraiProject/YiriMirai) 一个轻量级、低耦合的基于 mirai-api-http 的 Python SDK。 -- PyMySQL MySQL驱动 - [OpenAI API](https://openai.com/api/) OpenAI API ## 项目结构 @@ -42,11 +42,7 @@ 按照[此教程](https://yiri-mirai.wybxc.cc/tutorials/01/configuration)配置Mirai及YiriMirai -### 3. 配置MySQL数据库 - -安装MySQL数据库,创建数据库`qchatgpt` - -### 4. 配置此程序 +### 3. 配置主程序 1. 克隆此项目 diff --git a/config-template.py b/config-template.py index 39bff7a8..a34ae187 100644 --- a/config-template.py +++ b/config-template.py @@ -15,20 +15,6 @@ mirai_http_api_config = { "qq": 0 } -# [必需] MySQL数据库的配置 -# host: 数据库地址 -# port: 数据库端口 -# user: 数据库用户名 -# password: 数据库密码 -# database: 数据库名 -mysql_config = { - "host": "", - "port": 3306, - "user": "", - "password": "", - "database": "" -} - # [必需] OpenAI的配置 # api_key: OpenAI的API Key openai_config = { diff --git a/main.py b/main.py index f13921c8..1b581bfa 100644 --- a/main.py +++ b/main.py @@ -17,9 +17,8 @@ log_colors_config = { def init_db(): - import config import pkg.database.manager - database = pkg.database.manager.DatabaseManager(**config.mysql_config) + database = pkg.database.manager.DatabaseManager() database.initialize_database() @@ -54,7 +53,7 @@ def main(): # 主启动流程 openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) - database = pkg.database.manager.DatabaseManager(**config.mysql_config) + database = pkg.database.manager.DatabaseManager() # 加载所有未超时的session pkg.openai.session.load_sessions() diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 1e8d79c5..9d46f828 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,52 +1,35 @@ -import threading import time -import pymysql from pymysql.converters import escape_string +import sqlite3 + import config inst = None class DatabaseManager: - host = '' - port = 0 - user = '' - password = '' - database = '' conn = None cursor = None - def __init__(self, host: str, port: int, user: str, password: str, database: str): - self.host = host - self.port = port - self.user = user - self.password = password - self.database = database + def __init__(self): self.reconnect() - heartbeat_proxy = threading.Thread(target=self.heartbeat, daemon=True) - heartbeat_proxy.start() global inst inst = self - def heartbeat(self): - while True: - time.sleep(30) - self.conn.ping(reconnect=True) - def reconnect(self): - self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password, - database=self.database, autocommit=True) + self.conn = sqlite3.connect('database.db', check_same_thread=False) + # self.conn.isolation_level = None self.cursor = self.conn.cursor() def initialize_database(self): self.cursor.execute(""" create table if not exists `sessions` ( - `id` bigint not null auto_increment primary key, + `id` INTEGER PRIMARY KEY AUTOINCREMENT, `name` varchar(255) not null, `type` varchar(255) not null, `number` bigint not null, @@ -56,6 +39,7 @@ class DatabaseManager: `prompt` text not null ) """) + self.conn.commit() print('Database initialized.') def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, @@ -73,27 +57,32 @@ class DatabaseManager: values ('{}', '{}', {}, {}, {}, '{}') """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, last_interact_timestamp, escape_string(prompt))) + self.conn.commit() else: self.cursor.execute(""" update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' where `type` = '{}' and `number` = {} and `create_timestamp` = {} """.format(last_interact_timestamp, escape_string(prompt), subject_type, subject_number, create_timestamp)) + self.conn.commit() def explicit_close_session(self, session_name: str, create_timestamp: int): self.cursor.execute(""" update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) + self.conn.commit() def set_session_ongoing(self, session_name: str, create_timestamp: int): self.cursor.execute(""" update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) + self.conn.commit() def set_session_expired(self, session_name: str, create_timestamp: int): self.cursor.execute(""" update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) + self.conn.commit() # 记载还没过期的session数据 def load_valid_sessions(self) -> dict: