diff --git a/main.py b/main.py index d4be11d4..18c002f7 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,17 @@ import os import shutil +import sys + +import pkg.openai.manager +import pkg.database.manager +import pkg.openai.session + + +def init_db(): + import config + database = pkg.database.manager.DatabaseManager(**config.mysql_config) + + database.initialize_database() def main(): @@ -12,8 +24,17 @@ def main(): assert os.path.exists('config.py') import config - # print(config.mirai_http_api_config) + # 主启动流程 + openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) + + database = pkg.database.manager.DatabaseManager(**config.mysql_config) + + # 加载所有未超时的session + pkg.openai.session.load_sessions() if __name__ == '__main__': + if len(sys.argv) > 1 and sys.argv[1] == 'init_db': + init_db() + sys.exit(0) main() diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 2a814338..d5716f09 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,5 +1,9 @@ +import time + import pymysql +import config + inst = None @@ -26,9 +30,70 @@ class DatabaseManager: def reconnect(self): self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password, - database=self.database) + database=self.database, autocommit=True) self.cursor = self.conn.cursor() + def initialize_database(self): + self.cursor.execute(""" + create table if not exists `sessions` ( + `id` bigint not null auto_increment primary key, + `name` varchar(255) not null, + `type` varchar(255) not null, + `number` bigint not null, + `create_timestamp` bigint not null, + `last_interact_timestamp` bigint not null, + `prompt` text not null + ) + """) + print('Database initialized.') + + def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, + last_interact_timestamp: int, prompt: str): + # 检查是否已经有了此name和create_timestamp的session + # 如果有,就更新prompt和last_interact_timestamp + # 如果没有,就插入一条新的记录 + self.cursor.execute(""" + select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} + """.format(subject_type, subject_number, create_timestamp)) + count = self.cursor.fetchone()[0] + if count == 0: + self.cursor.execute(""" + insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) + values ('{}', '{}', {}, {}, {}, '{}') + """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, + last_interact_timestamp, prompt)) + else: + self.cursor.execute(""" + update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' + where `type` = '{}' and `number` = {} and `create_timestamp` = {} + """.format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp)) + + # 记载还没过期的session数据 + def load_valid_sessions(self) -> dict: + # 从数据库中加载所有还没过期的session + self.cursor.execute(""" + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt` + from `sessions` where `last_interact_timestamp` > {} + """.format(int(time.time()) - config.session_expire_time)) + results = self.cursor.fetchall() + sessions = {} + for result in results: + session_name = result[0] + subject_type = result[1] + subject_number = result[2] + create_timestamp = result[3] + last_interact_timestamp = result[4] + prompt = result[5] + + sessions[session_name] = { + 'subject_type': subject_type, + 'subject_number': subject_number, + 'create_timestamp': create_timestamp, + 'last_interact_timestamp': last_interact_timestamp, + 'prompt': prompt + } + return sessions + def get_inst() -> DatabaseManager: global inst diff --git a/pkg/openai/session.py b/pkg/openai/session.py index e01acdcf..a3e9bcdc 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -1,9 +1,41 @@ import time import pkg.openai.manager +import pkg.database.manager + +sessions = {} -session = {} +def load_sessions(): + global sessions + + db_inst = pkg.database.manager.get_inst() + + session_data = db_inst.load_valid_sessions() + + for session_name in session_data: + temp_session = Session(session_name) + temp_session.name = session_name + temp_session.create_timestamp = session_data[session_name]['create_timestamp'] + temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] + temp_session.prompt = session_data[session_name]['prompt'] + + sessions[session_name] = temp_session + + +def get_session(session_name: str): + global sessions + if session_name not in sessions: + sessions[session_name] = Session(session_name) + return sessions[session_name] + + +def dump_session(session_name: str): + global sessions + if session_name in sessions: + assert isinstance(sessions[session_name], Session) + sessions[session_name].persistence() + del sessions[session_name] # 通用的OpenAI API交互session @@ -23,17 +55,14 @@ class Session: self.name = name self.create_timestamp = int(time.time()) - global session - session[name] = self - # 请求回复 # 这个函数是阻塞的 def append(self, text: str) -> str: - self.prompt += self.user_name + ':' + text + '\n'+self.bot_name+':' + self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' self.last_interact_timestamp = int(time.time()) # 向API请求补全 - response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name+':') + response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':') # print(response) # 处理回复 @@ -50,4 +79,12 @@ class Session: return res_ans def persistence(self): - pass + db_inst = pkg.database.manager.get_inst() + + name_spt = self.name.split('_') + + subject_type = name_spt[0] + subject_number = int(name_spt[1]) + + db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, + self.prompt)