From 9cd7e49804c7ffaf16a9cc3c141c42b66fb772f8 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sat, 11 Mar 2023 23:44:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=86=E7=A6=BB=E5=82=A8=E5=AD=98?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=E6=83=85=E6=99=AF=E9=A2=84=E8=AE=BE=E5=92=8C?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/database/manager.py | 45 ++++++++++++++++++++++++++++++----------- pkg/openai/session.py | 42 +++++++++++++++++++++++--------------- 2 files changed, 59 insertions(+), 28 deletions(-) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 5fde3c29..cf452a5c 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -52,10 +52,23 @@ class DatabaseManager: `create_timestamp` bigint not null, `last_interact_timestamp` bigint not null, `status` varchar(255) not null default 'on_going', + `default_prompt` text not null default '', `prompt` text not null ) """) + # 检查sessions表是否存在`default_prompt`字段 + self.__execute__("PRAGMA table_info('sessions')") + columns = self.cursor.fetchall() + has_default_prompt = False + for field in columns: + if field[1] == 'default_prompt': + has_default_prompt = True + break + if not has_default_prompt: + self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") + + self.__execute__(""" create table if not exists `account_fee`( `id` INTEGER PRIMARY KEY AUTOINCREMENT, @@ -75,7 +88,7 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str): + last_interact_timestamp: int, prompt: str, default_prompt: str = ''): """持久化指定session""" # 检查是否已经有了此name和create_timestamp的session @@ -88,13 +101,13 @@ class DatabaseManager: if count == 0: sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) - values (?, ?, ?, ?, ?, ?) + insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`) + values (?, ?, ?, ?, ?, ?, ?) """ self.__execute__(sql, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt)) + last_interact_timestamp, prompt, default_prompt)) else: sql = """ update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? @@ -126,7 +139,7 @@ class DatabaseManager: # 从数据库中加载所有还没过期的session config = pkg.utils.context.get_config() self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) results = self.cursor.fetchall() @@ -139,6 +152,7 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 if status == 'on_going': @@ -147,7 +161,8 @@ class DatabaseManager: 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } else: if session_name in sessions: @@ -159,7 +174,7 @@ class DatabaseManager: def last_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 """.format(session_name, cursor_timestamp)) @@ -175,20 +190,22 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] return { 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 """.format(session_name, cursor_timestamp)) @@ -204,19 +221,21 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] return { 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page)) results = self.cursor.fetchall() @@ -229,13 +248,15 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] sessions.append({ 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt }) return sessions diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 0d627d16..3afe3990 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -75,6 +75,8 @@ def load_sessions(): except Exception: temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) temp_session.persistence() + temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \ + session_data[session_name]['default_prompt'] else [] sessions[session_name] = temp_session @@ -104,6 +106,9 @@ class Session: prompt = [] """使用list来保存会话中的回合""" + default_prompt = [] + """本session的默认prompt""" + create_timestamp = 0 """会话创建时间""" @@ -145,8 +150,8 @@ class Session: self.response_lock = threading.Lock() - self.prompt = self.get_default_prompt() - logging.debug("prompt is: {}".format(self.prompt)) + self.default_prompt = self.get_default_prompt() + logging.debug("prompt is: {}".format(self.default_prompt)) # 设定检查session最后一次对话是否超过过期时间的计时器 def schedule(self): @@ -190,11 +195,11 @@ class Session: self.last_interact_timestamp = int(time.time()) # 触发插件事件 - if self.prompt == self.get_default_prompt(): + if not self.prompt: args = { 'session_name': self.name, 'session': self, - 'default_prompt': self.prompt, + 'default_prompt': self.default_prompt, } event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) @@ -212,7 +217,6 @@ class Session: # 成功获取,处理回复 res_test = message res_ans = res_test - # 去除开头可能的提示 res_ans_spt = res_test.split("\n\n") @@ -220,7 +224,6 @@ class Session: del (res_ans_spt[0]) res_ans = '\n\n'.join(res_ans_spt) - # 将此次对话的双方内容加入到prompt中 self.prompt.append({'role': 'user', 'content': text}) self.prompt.append({'role': 'assistant', 'content': res_ans}) @@ -249,25 +252,29 @@ class Session: def cut_out(self, msg: str, max_tokens: int) -> list: """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" # 如果用户消息长度超过max_tokens,直接返回 - - temp_prompt = [ + temp_prompt: list = [] + temp_prompt += self.default_prompt + temp_prompt.append( { 'role': 'user', 'content': msg } - ] + ) + + token_count = 0 + for item in temp_prompt: + token_count += len(item['content']) - token_count = len(msg) # 倒序遍历prompt for i in range(len(self.prompt) - 1, -1, -1): if token_count >= max_tokens: break - # 将prompt加到temp_prompt头部 - temp_prompt.insert(0, self.prompt[i]) + # 将prompt加到temp_prompt倒数第二个位置 + temp_prompt.insert(len(self.default_prompt), self.prompt[i]) token_count += len(self.prompt[i]['content']) - logging.debug('cut_out: {}'.format(str(temp_prompt))) + logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4))) return temp_prompt @@ -284,11 +291,11 @@ class Session: subject_number = int(name_spt[1]) db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt)) + json.dumps(self.prompt), json.dumps(self.default_prompt)) # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): - if self.prompt[-1]['role'] != "system": + if self.prompt: self.persistence() if explicit: # 触发插件事件 @@ -305,7 +312,8 @@ class Session: if expired: pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - self.prompt = self.get_default_prompt(use_prompt) + self.default_prompt = self.get_default_prompt(use_prompt) + self.prompt = [] self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) self.just_switched_to_exist_session = False @@ -334,6 +342,7 @@ class Session: except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, last_one['prompt']) self.persistence() + self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else [] self.just_switched_to_exist_session = True return self @@ -353,6 +362,7 @@ class Session: except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, next_one['prompt']) self.persistence() + self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else [] self.just_switched_to_exist_session = True return self