diff --git a/.gitignore b/.gitignore index 362973b7..10cc092c 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,6 @@ prompts/ logs/ sensitive.json temp/ -current_tag \ No newline at end of file +current_tag +scenario/ +!scenario/default-template.json \ No newline at end of file diff --git a/config-template.py b/config-template.py index c1e7103c..88c969f8 100644 --- a/config-template.py +++ b/config-template.py @@ -79,6 +79,11 @@ default_prompt = { "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”", } +# 实验性设置项: JSON完整情景导入 +# 预设prompt模式 +# 参考值:旧版本方式:default | 完整情景:full_scenario +preset_mode = "default" + # 群内响应规则 # 符合此消息的群内消息即使不包含at机器人也会响应 # 支持消息前缀匹配及正则表达式匹配 diff --git a/main.py b/main.py index 4d87c7a7..9f868d75 100644 --- a/main.py +++ b/main.py @@ -182,6 +182,7 @@ def main(first_time_init=False): import pkg.openai.dprompt pkg.openai.dprompt.read_prompt_from_file() + pkg.openai.dprompt.read_scenario_from_file() pkg.utils.context.context['logger_handler'] = sh # 主启动流程 @@ -337,6 +338,10 @@ if __name__ == '__main__': if not os.path.exists("sensitive.json"): shutil.copy("sensitive-template.json", "sensitive.json") + # 检查是否有scenario/default.json + if not os.path.exists("scenario/default.json"): + shutil.copy("scenario/default-template.json", "scenario/default.json") + # 检查temp目录 if not os.path.exists("temp/"): os.mkdir("temp/") diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 752664d0..999d7315 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -53,10 +53,23 @@ class DatabaseManager: `create_timestamp` bigint not null, `last_interact_timestamp` bigint not null, `status` varchar(255) not null default 'on_going', + `default_prompt` text not null default '', `prompt` text not null ) """) + # 检查sessions表是否存在`default_prompt`字段 + self.__execute__("PRAGMA table_info('sessions')") + columns = self.cursor.fetchall() + has_default_prompt = False + for field in columns: + if field[1] == 'default_prompt': + has_default_prompt = True + break + if not has_default_prompt: + self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") + + self.__execute__(""" create table if not exists `account_fee`( `id` INTEGER PRIMARY KEY AUTOINCREMENT, @@ -76,7 +89,7 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str): + last_interact_timestamp: int, prompt: str, default_prompt: str = ''): """持久化指定session""" # 检查是否已经有了此name和create_timestamp的session @@ -89,13 +102,13 @@ class DatabaseManager: if count == 0: sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) - values (?, ?, ?, ?, ?, ?) + insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`) + values (?, ?, ?, ?, ?, ?, ?) """ self.__execute__(sql, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt)) + last_interact_timestamp, prompt, default_prompt)) else: sql = """ update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? @@ -127,7 +140,7 @@ class DatabaseManager: # 从数据库中加载所有还没过期的session config = pkg.utils.context.get_config() self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) results = self.cursor.fetchall() @@ -140,6 +153,7 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 if status == 'on_going': @@ -148,7 +162,8 @@ class DatabaseManager: 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } else: if session_name in sessions: @@ -160,7 +175,7 @@ class DatabaseManager: def last_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 """.format(session_name, cursor_timestamp)) @@ -176,20 +191,22 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] return { 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 """.format(session_name, cursor_timestamp)) @@ -205,19 +222,21 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] return { 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt } # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page)) results = self.cursor.fetchall() @@ -230,13 +249,15 @@ class DatabaseManager: last_interact_timestamp = result[4] prompt = result[5] status = result[6] + default_prompt = result[7] sessions.append({ 'subject_type': subject_type, 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt + 'prompt': prompt, + 'default_prompt': default_prompt }) return sessions diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index 3aba31cb..84dc32fc 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -1,4 +1,6 @@ # 多情景预设值管理 +import json +import logging __current__ = "default" """当前默认使用的情景预设的名称 @@ -9,8 +11,10 @@ __current__ = "default" __prompts_from_files__ = {} """从文件中读取的情景预设值""" +__scenario_from_files__ = {} -def read_prompt_from_file() -> str: + +def read_prompt_from_file(): """从文件读取预设值""" # 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 # 保存在__prompts_from_files__中 @@ -23,6 +27,19 @@ def read_prompt_from_file() -> str: __prompts_from_files__[file] = f.read() +def read_scenario_from_file(): + """从JSON文件读取情景预设""" + global __scenario_from_files__ + import os + + __scenario_from_files__ = {} + for file in os.listdir("scenario"): + if file == "default-template.json": + continue + with open(os.path.join("scenario", file), encoding="utf-8") as f: + __scenario_from_files__[file] = json.load(f) + + def get_prompt_dict() -> dict: """获取预设值字典""" import config @@ -65,15 +82,40 @@ def set_to_default(): __current__ = list(default_dict.keys())[0] -def get_prompt(name: str = None) -> str: +def get_prompt(name: str = None) -> list: + global __scenario_from_files__ + import config + preset_mode = config.preset_mode + """获取预设值""" if name is None: name = get_current() - default_dict = get_prompt_dict() + # JSON预设方式 + if preset_mode == 'full_scenario': + import os - for key in default_dict: - if key.lower().startswith(name.lower()): - return default_dict[key] + for key in __scenario_from_files__: + if key.lower().startswith(name.lower()): + logging.debug('成功加载情景预设从JSON文件: {}'.format(key)) + return __scenario_from_files__[key]['prompt'] + + # 默认预设方式 + elif preset_mode == 'default': - raise KeyError("未找到情景预设: " + name) + default_dict = get_prompt_dict() + + for key in default_dict: + if key.lower().startswith(name.lower()): + return [ + { + "role": "user", + "content": default_dict[key] + }, + { + "role": "assistant", + "content": "好的。" + } + ] + + raise KeyError("未找到默认情景预设: " + name) diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 7127db8c..4bb82038 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -88,4 +88,4 @@ class KeysManager: for key_name in self.api_key: if self.api_key[key_name] == api_key: return key_name - return "" \ No newline at end of file + return "" diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 233b9dc0..56b9b328 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -75,6 +75,8 @@ def load_sessions(): except Exception: temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) temp_session.persistence() + temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \ + session_data[session_name]['default_prompt'] else [] sessions[session_name] = temp_session @@ -104,6 +106,9 @@ class Session: prompt = [] """使用list来保存会话中的回合""" + default_prompt = [] + """本session的默认prompt""" + create_timestamp = 0 """会话创建时间""" @@ -129,24 +134,13 @@ class Session: # 从配置文件获取会话预设信息 def get_default_prompt(self, use_default: str = None): - config = pkg.utils.context.get_config() - import pkg.openai.dprompt as dprompt if use_default is None: - current_default_prompt = dprompt.get_prompt(dprompt.get_current()) - else: - current_default_prompt = dprompt.get_prompt(use_default) + use_default = dprompt.get_current() - return [ - { - 'role': 'user', - 'content': current_default_prompt - }, { - 'role': 'assistant', - 'content': 'ok' - } - ] + current_default_prompt = dprompt.get_prompt(use_default) + return current_default_prompt def __init__(self, name: str): self.name = name @@ -155,7 +149,9 @@ class Session: self.schedule() self.response_lock = threading.Lock() - self.prompt = self.get_default_prompt() + + self.default_prompt = self.get_default_prompt() + logging.debug("prompt is: {}".format(self.default_prompt)) # 设定检查session最后一次对话是否超过过期时间的计时器 def schedule(self): @@ -199,11 +195,11 @@ class Session: self.last_interact_timestamp = int(time.time()) # 触发插件事件 - if self.prompt == self.get_default_prompt(): + if not self.prompt: args = { 'session_name': self.name, 'session': self, - 'default_prompt': self.prompt, + 'default_prompt': self.default_prompt, } event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) @@ -256,25 +252,29 @@ class Session: def cut_out(self, msg: str, max_tokens: int) -> list: """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" # 如果用户消息长度超过max_tokens,直接返回 - - temp_prompt = [ + temp_prompt: list = [] + temp_prompt += self.default_prompt + temp_prompt.append( { 'role': 'user', 'content': msg } - ] + ) + + token_count = 0 + for item in temp_prompt: + token_count += len(item['content']) - token_count = len(msg) # 倒序遍历prompt for i in range(len(self.prompt) - 1, -1, -1): if token_count >= max_tokens: break - # 将prompt加到temp_prompt头部 - temp_prompt.insert(0, self.prompt[i]) + # 将prompt加到temp_prompt倒数第二个位置 + temp_prompt.insert(len(self.default_prompt), self.prompt[i]) token_count += len(self.prompt[i]['content']) - logging.debug('cut_out: {}'.format(str(temp_prompt))) + logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4))) return temp_prompt @@ -291,11 +291,11 @@ class Session: subject_number = int(name_spt[1]) db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt)) + json.dumps(self.prompt), json.dumps(self.default_prompt)) # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): - if self.prompt[-1]['role'] != "system": + if self.prompt: self.persistence() if explicit: # 触发插件事件 @@ -311,7 +311,9 @@ class Session: if expired: pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) - self.prompt = self.get_default_prompt(use_prompt) + + self.default_prompt = self.get_default_prompt(use_prompt) + self.prompt = [] self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) self.just_switched_to_exist_session = False @@ -340,6 +342,7 @@ class Session: except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, last_one['prompt']) self.persistence() + self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else [] self.just_switched_to_exist_session = True return self @@ -359,6 +362,7 @@ class Session: except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, next_one['prompt']) self.persistence() + self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else [] self.just_switched_to_exist_session = True return self diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index f63ffbef..1699ce39 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -234,7 +234,7 @@ def process_command(session_name: str, text_message: str, mgr, config, if len(msg) >= 2: reply_str += "#{} 创建:{} {}\n".format(i + page * 10, datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - msg[1]['content']) + msg[0]['content']) else: reply_str += "#{} 创建:{} {}\n".format(i + page * 10, datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), diff --git a/scenario/default-template.json b/scenario/default-template.json new file mode 100644 index 00000000..d9b7267a --- /dev/null +++ b/scenario/default-template.json @@ -0,0 +1,12 @@ +{ + "prompt": [ + { + "role": "system", + "content": "You are a helpful assistant. 如果我需要帮助,你要说“输入!help获得帮助”" + }, + { + "role": "assistant", + "content": "好的,我是一个能干的AI助手。 如果你需要帮助,我会说“输入!help获得帮助”" + } + ] +} \ No newline at end of file