diff --git a/.gitignore b/.gitignore index 88bac91f..645c403b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ __pycache__/ database.db qchatgpt.log config.py -banlist.py \ No newline at end of file +banlist.py +plugins/ +!plugins/__init__.py \ No newline at end of file diff --git a/README.md b/README.md index 17ed53d5..40f21ada 100644 --- a/README.md +++ b/README.md @@ -99,13 +99,26 @@ python3 main.py - 如提示安装`uvicorn`或`hypercorn`请*不要*安装,这两个不是必需的,目前存在未知原因bug - 如报错`TypeError: As of 3.10, the *loop* parameter was removed from Lock() since it is no longer necessary`, 请参考 [此处](https://github.com/RockChinQ/QChatGPT/issues/5) - ## 🚀使用 查看[Wiki功能使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E5%8A%9F%E8%83%BD%E4%BD%BF%E7%94%A8#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F) +## 🧩插件生态 + +现已支持自行开发插件对功能进行扩展或自定义程序行为 +详见[Wiki插件使用页](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) +开发教程见[Wiki插件开发页](https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91) + +### 示例插件 + +在`tests/plugin_examples`目录下,将其整个目录复制到`plugins`目录下即可使用 + +- `cmdcn` - 主程序指令中文形式 +- `hello_plugin` - 在收到消息`hello`时回复相应消息 +- `urlikethisijustsix` - 收到冒犯性消息时回复相应消息 + ## 👍赞赏 赞赏码 \ No newline at end of file diff --git a/main.py b/main.py index a1d99bf4..5c831ae1 100644 --- a/main.py +++ b/main.py @@ -39,6 +39,38 @@ def init_db(): known_exception_caught = False +def reset_logging(): + assert os.path.exists('config.py') + + config = importlib.import_module('config') + + import pkg.utils.context + + if pkg.utils.context.context['logger_handler'] is not None: + logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) + + for handler in logging.getLogger().handlers: + logging.getLogger().removeHandler(handler) + + logging.basicConfig(level=config.logging_level, # 设置日志输出格式 + filename='qchatgpt.log', # log日志输出的文件位置和文件名 + format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s", + # 日志输出的格式 + # -8表示占位符,让输出左对齐,输出长度都为8位 + datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 + ) + sh = logging.StreamHandler() + sh.setLevel(config.logging_level) + sh.setFormatter(colorlog.ColoredFormatter( + fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " + "%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors=log_colors_config + )) + logging.getLogger().addHandler(sh) + return sh + + def main(first_time_init=False): global known_exception_caught @@ -52,25 +84,7 @@ def main(first_time_init=False): import pkg.utils.context pkg.utils.context.set_config(config) - if pkg.utils.context.context['logger_handler'] is not None: - logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) - - logging.basicConfig(level=config.logging_level, # 设置日志输出格式 - filename='qchatgpt.log', # log日志输出的文件位置和文件名 - format="[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s", - # 日志输出的格式 - # -8表示占位符,让输出左对齐,输出长度都为8位 - datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式 - ) - sh = logging.StreamHandler() - sh.setLevel(config.logging_level) - sh.setFormatter(colorlog.ColoredFormatter( - fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : " - "%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - log_colors=log_colors_config - )) - logging.getLogger().addHandler(sh) + sh = reset_logging() # 检查是否设置了管理员 if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): @@ -117,9 +131,16 @@ def main(first_time_init=False): timeout=config.process_message_timeout, retry=config.retry_times, first_time_init=first_time_init) + # 加载插件 + import pkg.plugin.host + pkg.plugin.host.load_plugins() + + pkg.plugin.host.initialize_plugins() + if first_time_init: # 不是热重载之后的启动,则不启动新的bot线程 import mirai.exceptions + def run_bot_wrapper(): global known_exception_caught try: @@ -155,7 +176,7 @@ def main(first_time_init=False): known_exception_caught = True else: logging.error( - "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/issues 查找或提issue".format(e)) + "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e)) known_exception_caught = True raise e @@ -201,6 +222,9 @@ def stop(): import pkg.qqbot.manager import pkg.openai.session try: + import pkg.plugin.host + pkg.plugin.host.unload_plugins() + qqbot_inst = pkg.utils.context.get_qqbot_manager() assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager) @@ -230,6 +254,12 @@ if __name__ == '__main__': elif len(sys.argv) > 1 and sys.argv[1] == 'update': try: + try: + import pkg.utils.pkgmgr + pkg.utils.pkgmgr.ensure_dulwich() + except: + pass + from dulwich import porcelain repo = porcelain.open_repo('.') diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index 56093844..6e0ea3ed 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -40,6 +40,9 @@ class DataGatherer: except: return + def get_usage(self, key_md5): + return self.usage[key_md5] if key_md5 in self.usage else {} + def report_text_model_usage(self, model, text): key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index ed89a680..78162fa9 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -2,10 +2,8 @@ import hashlib import logging -import pkg.database.manager -import pkg.qqbot.manager -import pkg.utils.context - +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models class KeysManager: api_key = {} @@ -39,13 +37,10 @@ class KeysManager: elif type(api_key) is list: for i in range(len(api_key)): self.api_key[str(i)] = api_key[i] - - self.auto_switch() # 从usage中删除未加载的api-key的记录 # 不删了,也许会运行时添加曾经有记录的api-key - if 'exceeded_keys' in pkg.utils.context.context and pkg.utils.context.context['exceeded_keys'] is not None: - self.exceeded = pkg.utils.context.context['exceeded_keys'] + self.auto_switch() # 根据tested自动切换到可用的api-key # 返回是否切换成功, 切换后的api-key的别名 @@ -53,7 +48,16 @@ class KeysManager: for key_name in self.api_key: if self.api_key[key_name] not in self.exceeded: self.using_key = self.api_key[key_name] + logging.info("使用api-key:" + key_name) + + # 触发插件事件 + args = { + "key_name": key_name, + "key_list": self.api_key.keys() + } + _ = plugin_host.emit(plugin_models.KeySwitched, **args) + return True, key_name self.using_key = list(self.api_key.values())[0] diff --git a/pkg/openai/session.py b/pkg/openai/session.py index b773197a..ad073dd0 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -6,6 +6,9 @@ import pkg.openai.manager import pkg.database.manager import pkg.utils.context +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models + # 运行时保存的所有session sessions = {} @@ -120,6 +123,17 @@ class Session: config = pkg.utils.context.get_config() if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: logging.info('session {} 已过期'.format(self.name)) + + # 触发插件事件 + args = { + 'session_name': self.name, + 'session': self, + 'session_expire_time': config.session_expire_time + } + event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args) + if event.is_prevented_default(): + return + self.reset(expired=True, schedule_new=False) # 删除此session @@ -131,6 +145,18 @@ class Session: def append(self, text: str) -> str: self.last_interact_timestamp = int(time.time()) + # 触发插件事件 + if self.prompt == self.get_default_prompt(): + args = { + 'session_name': self.name, + 'session': self, + 'default_prompt': self.prompt, + } + + event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) + if event.is_prevented_default(): + return None + # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 config = pkg.utils.context.get_config() max_rounds = 1000 # 不再限制回合数 @@ -220,6 +246,15 @@ class Session: if self.prompt != self.get_default_prompt(): self.persistence() if explicit: + # 触发插件事件 + args = { + 'session_name': self.name, + 'session': self + } + + # 此事件不支持阻止默认行为 + _ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args) + pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: diff --git a/pkg/plugin/__init__.py b/pkg/plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py new file mode 100644 index 00000000..fea638aa --- /dev/null +++ b/pkg/plugin/host.py @@ -0,0 +1,273 @@ +# 插件管理模块 +import asyncio +import logging +import importlib +import os +import pkgutil +import sys +import traceback + +import pkg.utils.context as context +import pkg.plugin.switch as switch + +from mirai import Mirai + +__plugins__ = {} +""" +插件列表 + +示例: +{ + "example": { + "path": "plugins/example/main.py", + "enabled: True, + "name": "example", + "description": "example", + "version": "0.0.1", + "author": "RockChinQ", + "class": , + "hooks": { + "person_message": [ + + ] + }, + "instance": None + } +}""" + + +__current_module_path__ = "" + + +def walk_plugin_path(module, prefix='', path_prefix=''): + global __current_module_path__ + """遍历插件路径""" + for item in pkgutil.iter_modules(module.__path__): + if item.ispkg: + logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name)) + walk_plugin_path(__import__(module.__name__ + '.' + item.name, fromlist=['']), + prefix + item.name + '.', path_prefix + item.name + '/') + else: + logging.debug("扫描插件模块: plugins/{}".format(path_prefix + item.name + '.py')) + logging.info('加载模块: plugins/{}'.format(path_prefix + item.name + '.py')) + __current_module_path__ = "plugins/"+path_prefix + item.name + '.py' + + importlib.import_module(module.__name__ + '.' + item.name) + + +def load_plugins(): + """ 加载插件 """ + logging.info("加载插件") + PluginHost() + walk_plugin_path(__import__('plugins')) + + logging.debug(__plugins__) + + # 加载开关数据 + switch.load_switch() + + +def initialize_plugins(): + """ 初始化插件 """ + logging.info("初始化插件") + for plugin in __plugins__.values(): + if not plugin['enabled']: + continue + try: + plugin['instance'] = plugin["class"](plugin_host=context.get_plugin_host()) + except: + logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) + + +def unload_plugins(): + """ 卸载插件 """ + for plugin in __plugins__.values(): + if plugin['enabled'] and plugin['instance'] is not None: + if not hasattr(plugin['instance'], '__del__'): + logging.warning("插件{}没有定义析构函数".format(plugin['name'])) + else: + try: + plugin['instance'].__del__() + logging.info("卸载插件: {}".format(plugin['name'])) + except: + logging.error("插件{}卸载时发生错误: {}".format(plugin['name'], sys.exc_info())) + + +def install_plugin(repo_url: str): + """ 安装插件,从git储存库获取并解决依赖 """ + try: + import pkg.utils.pkgmgr + pkg.utils.pkgmgr.ensure_dulwich() + except: + pass + + try: + import dulwich + except ModuleNotFoundError: + raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + + from dulwich import porcelain + + logging.info("克隆插件储存库: {}".format(repo_url)) + repo = porcelain.clone(repo_url, "plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/", checkout=True) + + # 检查此目录是否包含requirements.txt + if os.path.exists("plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/requirements.txt"): + logging.info("检测到requirements.txt,正在安装依赖") + import pkg.utils.pkgmgr + pkg.utils.pkgmgr.install_requirements("plugins/"+repo_url.split(".git")[0].split("/")[-1]+"/requirements.txt") + + import main + main.reset_logging() + + +class EventContext: + """ 事件上下文 """ + eid = 0 + """事件编号""" + + name = "" + + __prevent_default__ = False + """ 是否阻止默认行为 """ + + __prevent_postorder__ = False + """ 是否阻止后续插件的执行 """ + + __return_value__ = {} + """ 返回值 + 示例: + { + "example": [ + 'value1', + 'value2', + 3, + 4, + { + 'key1': 'value1', + }, + ['value1', 'value2'] + ] + } + """ + + def add_return(self, key: str, ret): + """添加返回值""" + if key not in self.__return_value__: + self.__return_value__[key] = [] + self.__return_value__[key].append(ret) + + def get_return(self, key: str): + """获取key的所有返回值""" + if key in self.__return_value__: + return self.__return_value__[key] + return None + + def get_return_value(self, key: str): + """获取key的首个返回值""" + if key in self.__return_value__: + return self.__return_value__[key][0] + return None + + def prevent_default(self): + """阻止默认行为""" + self.__prevent_default__ = True + + def prevent_postorder(self): + """阻止后续插件执行""" + self.__prevent_postorder__ = True + + def is_prevented_default(self): + """是否阻止默认行为""" + return self.__prevent_default__ + + def is_prevented_postorder(self): + """是否阻止后序插件执行""" + return self.__prevent_postorder__ + + def __init__(self, name: str): + self.name = name + self.eid = EventContext.eid + self.__prevent_default__ = False + self.__prevent_postorder__ = False + self.__return_value__ = {} + EventContext.eid += 1 + + +def emit(event_name: str, **kwargs) -> EventContext: + """ 触发事件 """ + import pkg.utils.context as context + if context.get_plugin_host() is None: + return None + return context.get_plugin_host().emit(event_name, **kwargs) + + +class PluginHost: + """插件宿主""" + + def __init__(self): + context.set_plugin_host(self) + + def get_runtime_context(self) -> context: + """获取运行时上下文(pkg.utils.context模块的对象) + + 此上下文用于和主程序其他模块交互(数据库、QQ机器人、OpenAI接口等) + 详见pkg.utils.context模块 + 其中的context变量保存了其他重要模块的类对象,可以使用这些对象进行交互 + """ + return context + + def get_bot(self) -> Mirai: + """获取机器人对象""" + return context.get_qqbot_manager().bot + + def send_person_message(self, person, message): + """发送私聊消息""" + asyncio.run(self.get_bot().send_friend_message(person, message)) + + def send_group_message(self, group, message): + """发送群消息""" + asyncio.run(self.get_bot().send_group_message(group, message)) + + def notify_admin(self, message): + """通知管理员""" + context.get_qqbot_manager().notify_admin(message) + + def emit(self, event_name: str, **kwargs) -> EventContext: + """ 触发事件 """ + event_context = EventContext(event_name) + logging.debug("触发事件: {} ({})".format(event_name, event_context.eid)) + for plugin in __plugins__.values(): + + if not plugin['enabled']: + continue + + if plugin['instance'] is None: + # 从关闭状态切到开启状态之后,重新加载插件 + try: + plugin['instance'] = plugin["class"]() + except: + logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) + continue + + for hook in plugin['hooks'].get(event_name, []): + try: + already_prevented_default = event_context.is_prevented_default() + + kwargs['host'] = context.get_plugin_host() + kwargs['event'] = event_context + + hook(plugin['instance'], **kwargs) + + if event_context.is_prevented_default() and not already_prevented_default: + logging.debug("插件 {} 已要求阻止事件 {} 的默认行为".format(plugin['name'], event_name)) + + if event_context.is_prevented_postorder(): + logging.debug("插件 {} 阻止了后序插件的执行".format(plugin['name'])) + break + + except Exception as e: + logging.error("插件{}触发事件{}时发生错误".format(plugin['name'], event_name)) + logging.error(traceback.format_exc()) + + return event_context diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py new file mode 100644 index 00000000..2c0e35a6 --- /dev/null +++ b/pkg/plugin/models.py @@ -0,0 +1,219 @@ +import logging + +import pkg.plugin.host as host +import pkg.utils.context + +PersonMessageReceived = "person_message_received" +"""收到私聊消息时,在判断是否应该响应前触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + message_chain: mirai.models.message.MessageChain 消息链 +""" + +GroupMessageReceived = "group_message_received" +"""收到群聊消息时,在判断是否应该响应前触发(所有群消息) + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + message_chain: mirai.models.message.MessageChain 消息链 +""" + +PersonNormalMessageReceived = "person_normal_message_received" +"""判断为应该处理的私聊普通消息时触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + text_message: str 消息文本 + + returns (optional): + alter: str 修改后的消息文本 + reply: list 回复消息组件列表 +""" + +PersonCommandSent = "person_command_sent" +"""判断为应该处理的私聊指令时触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + command: str 指令 + params: list[str] 参数列表 + text_message: str 完整指令文本 + is_admin: bool 是否为管理员 + + returns (optional): + alter: str 修改后的完整指令文本 + reply: list 回复消息组件列表 +""" + +GroupNormalMessageReceived = "group_normal_message_received" +"""判断为应该处理的群聊普通消息时触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + text_message: str 消息文本 + + returns (optional): + alter: str 修改后的消息文本 + reply: list 回复消息组件列表 +""" + +GroupCommandSent = "group_command_sent" +"""判断为应该处理的群聊指令时触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + command: str 指令 + params: list[str] 参数列表 + text_message: str 完整指令文本 + is_admin: bool 是否为管理员 + + returns (optional): + alter: str 修改后的完整指令文本 + reply: list 回复消息组件列表 +""" + +NormalMessageResponded = "normal_message_responded" +"""获取到对普通消息的文字响应时触发 + kwargs: + launcher_type: str 发起对象类型(group/person) + launcher_id: int 发起对象ID(群号/QQ号) + sender_id: int 发送者ID(QQ号) + session: pkg.openai.session.Session 会话对象 + prefix: str 回复文字消息的前缀 + response_text: str 响应文本 + + returns (optional): + prefix: str 修改后的回复文字消息的前缀 + reply: list 替换回复消息组件列表 +""" + +SessionFirstMessageReceived = "session_first_message_received" +"""会话被第一次交互时触发 + kwargs: + session_name: str 会话名称(_) + session: pkg.openai.session.Session 会话对象 + default_prompt: str 预设值 +""" + +SessionExplicitReset = "session_reset" +"""会话被用户手动重置时触发,此事件不支持阻止默认行为 + kwargs: + session_name: str 会话名称(_) + session: pkg.openai.session.Session 会话对象 +""" + +SessionExpired = "session_expired" +"""会话过期时触发 + kwargs: + session_name: str 会话名称(_) + session: pkg.openai.session.Session 会话对象 + session_expire_time: int 已设置的会话过期时间(秒) +""" + +KeyExceeded = "key_exceeded" +"""api-key超额时触发 + kwargs: + key_name: str 超额的api-key名称 + usage: dict 超额的api-key使用情况 + exceeded_keys: list[str] 超额的api-key列表 +""" + +KeySwitched = "key_switched" +"""api-key超额切换成功时触发,此事件不支持阻止默认行为 + kwargs: + key_name: str 切换成功的api-key名称 + key_list: list[str] api-key列表 +""" + + +def on(event: str): + """注册事件监听器 + :param + event: str 事件名称 + """ + return Plugin.on(event) + + +__current_registering_plugin__ = "" + + +class Plugin: + + host: host.PluginHost + """插件宿主,提供插件的一些基础功能""" + + @classmethod + def on(cls, event): + """事件处理器装饰器 + + :param + event: 事件类型 + :return: + None + """ + + def wrapper(func): + plugin_hooks = host.__plugins__[__current_registering_plugin__]["hooks"] + + if event not in plugin_hooks: + plugin_hooks[event] = [] + plugin_hooks[event].append(func) + + host.__plugins__[__current_registering_plugin__]["hooks"] = plugin_hooks + + return func + + return wrapper + + +def register(name: str, description: str, version: str, author: str): + """注册插件, 此函数作为装饰器使用 + + Args: + name (str): 插件名称 + description (str): 插件描述 + version (str): 插件版本 + author (str): 插件作者 + + Returns: + None + """ + global __current_registering_plugin__ + + __current_registering_plugin__ = name + + host.__plugins__[name] = { + "name": name, + "description": description, + "version": version, + "author": author, + "hooks": {}, + "path": host.__current_module_path__, + "enabled": True, + "instance": None, + } + + def wrapper(cls: Plugin): + cls.name = name + cls.description = description + cls.version = version + cls.author = author + cls.host = pkg.utils.context.get_plugin_host() + cls.enabled = True + cls.path = host.__current_module_path__ + + # 存到插件列表 + host.__plugins__[name]["class"] = cls + + logging.info("插件注册完成: n='{}', d='{}', v={}, a='{}' ({})".format(name, description, version, author, cls)) + + return cls + + return wrapper diff --git a/pkg/plugin/switch.py b/pkg/plugin/switch.py new file mode 100644 index 00000000..ea3441fa --- /dev/null +++ b/pkg/plugin/switch.py @@ -0,0 +1,87 @@ +# 控制插件的开关 +import json +import logging +import os + +import pkg.plugin.host as host + + +def wrapper_dict_from_plugin_list() -> dict: + """ 将插件列表转换为开关json """ + switch = {} + + for plugin_name in host.__plugins__: + plugin = host.__plugins__[plugin_name] + + switch[plugin_name] = { + "path": plugin["path"], + "enabled": plugin["enabled"], + } + + return switch + + +def apply_switch(switch: dict): + """将开关数据应用到插件列表中""" + for plugin_name in switch: + host.__plugins__[plugin_name]["enabled"] = switch[plugin_name]["enabled"] + + +def dump_switch(): + """ 保存开关数据 """ + logging.debug("保存开关数据") + # 将开关数据写入plugins/switch.json + + switch = wrapper_dict_from_plugin_list() + + with open("plugins/switch.json", "w", encoding="utf-8") as f: + json.dump(switch, f, indent=4, ensure_ascii=False) + + +def load_switch(): + """ 加载开关数据 """ + logging.debug("加载开关数据") + # 读取plugins/switch.json + + switch = {} + + # 检查文件是否存在 + if not os.path.exists("plugins/switch.json"): + # 不存在则创建 + with open("plugins/switch.json", "w", encoding="utf-8") as f: + json.dump(switch, f, indent=4, ensure_ascii=False) + + with open("plugins/switch.json", "r", encoding="utf-8") as f: + switch = json.load(f) + + if switch is None: + switch = {} + + switch_modified = False + + switch_copy = switch.copy() + # 检查switch中多余的和path不相符的 + for plugin_name in switch_copy: + if plugin_name not in host.__plugins__: + del switch[plugin_name] + switch_modified = True + elif switch[plugin_name]["path"] != host.__plugins__[plugin_name]["path"]: + # 删除此不相符的 + del switch[plugin_name] + switch_modified = True + + # 检查plugin中多余的 + for plugin_name in host.__plugins__: + if plugin_name not in switch: + switch[plugin_name] = { + "path": host.__plugins__[plugin_name]["path"], + "enabled": host.__plugins__[plugin_name]["enabled"], + } + switch_modified = True + + # 如果switch有修改,保存 + if switch_modified: + dump_switch() + + # 应用开关数据 + apply_switch(switch) diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py new file mode 100644 index 00000000..507856c1 --- /dev/null +++ b/pkg/qqbot/command.py @@ -0,0 +1,313 @@ +# 指令处理模块 +import logging +import json +import datetime +import os +import threading + +import pkg.openai.session +import pkg.openai.manager +import pkg.utils.reloader +import pkg.utils.updater +import pkg.utils.context +import pkg.qqbot.message + +from mirai import Image + + +def config_operation(cmd, params): + reply = [] + config = pkg.utils.context.get_config() + reply_str = "" + if len(params) == 0: + reply = ["[bot]err:请输入配置项"] + else: + cfg_name = params[0] + if cfg_name == 'all': + reply_str = "[bot]所有配置项:\n\n" + for cfg in dir(config): + if not cfg.startswith('__') and not cfg == 'logging': + # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 + if isinstance(getattr(config, cfg), str): + reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg)) + elif isinstance(getattr(config, cfg), dict): + # 不进行unicode转义,并格式化 + reply_str += "{}: {}\n".format(cfg, + json.dumps(getattr(config, cfg), + ensure_ascii=False, indent=4)) + else: + reply_str += "{}: {}\n".format(cfg, getattr(config, cfg)) + reply = [reply_str] + elif cfg_name in dir(config): + if len(params) == 1: + # 按照配置项类型进行格式化 + if isinstance(getattr(config, cfg_name), str): + reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name)) + elif isinstance(getattr(config, cfg_name), dict): + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, + json.dumps(getattr(config, cfg_name), + ensure_ascii=False, indent=4)) + else: + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name)) + reply = [reply_str] + else: + cfg_value = " ".join(params[1:]) + # 类型转换,如果是json则转换为字典 + if cfg_value == 'true': + cfg_value = True + elif cfg_value == 'false': + cfg_value = False + elif cfg_value.isdigit(): + cfg_value = int(cfg_value) + elif cfg_value.startswith('{') and cfg_value.endswith('}'): + cfg_value = json.loads(cfg_value) + else: + try: + cfg_value = float(cfg_value) + except ValueError: + pass + + # 检查类型是否匹配 + if isinstance(getattr(config, cfg_name), type(cfg_value)): + setattr(config, cfg_name, cfg_value) + pkg.utils.context.set_config(config) + reply = ["[bot]配置项{}修改成功".format(cfg_name)] + else: + reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] + + else: + reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] + + return reply + + +def plugin_operation(cmd, params, is_admin): + reply = [] + + import pkg.plugin.host as plugin_host + import pkg.utils.updater as updater + + plugin_list = plugin_host.__plugins__ + + if len(params) == 0: + reply_str = "[bot]所有插件({}):\n\n".format(len(plugin_list)) + idx = 0 + for key in plugin_list: + plugin = plugin_list[key] + reply_str += "#{} {}:\n{}\nv{}\n作者: {}\n".format((idx+1), plugin['name'], plugin['description'], + plugin['version'], plugin['author']) + + if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): + reply_str += "源码: "+updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))+"\n" + + reply_str += "\n" + + idx += 1 + + reply = [reply_str] + elif params[0] == 'update': + # 更新所有插件 + if is_admin: + def closure(): + import pkg.utils.context + updated = [] + for key in plugin_list: + plugin = plugin_list[key] + if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): + success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1])) + if success: + updated.append(plugin['name']) + + # 检查是否有requirements.txt + pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...") + for key in plugin_list: + plugin = plugin_list[key] + if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"): + logging.info("{}检测到requirements.txt,安装依赖".format(plugin['name'])) + import pkg.utils.pkgmgr + pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt") + + import main + main.reset_logging() + + pkg.utils.context.get_qqbot_manager().notify_admin("[bot]已更新插件: {}".format(", ".join(updated))) + + threading.Thread(target=closure).start() + reply = ["[bot]正在更新所有插件,请勿重复发起..."] + else: + reply = ["[bot]err:权限不足"] + elif params[0].startswith("http"): + if is_admin: + + def closure(): + try: + plugin_host.install_plugin(params[0]) + pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件") + except Exception as e: + logging.error("插件安装失败:{}".format(e)) + pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e)) + + threading.Thread(target=closure, args=()).start() + reply = ["[bot]正在安装插件..."] + else: + reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"] + return reply + + +def process_command(session_name: str, text_message: str, mgr, config, + launcher_type: str, launcher_id: int, sender_id: int) -> list: + reply = [] + try: + logging.info( + "[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( + "..." if len(text_message) > 20 else ""))) + + cmd = text_message[1:].strip().split(' ')[0] + + params = text_message[1:].strip().split(' ')[1:] + if cmd == 'help': + reply = ["[bot]" + config.help_message] + elif cmd == 'reset': + pkg.openai.session.get_session(session_name).reset(explicit=True) + reply = ["[bot]会话已重置"] + elif cmd == 'last': + result = pkg.openai.session.get_session(session_name).last_session() + if result is None: + reply = ["[bot]没有前一次的对话"] + else: + datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( + '%Y-%m-%d %H:%M:%S') + reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format( + datetime_str) + result.prompt[ + :min(100, + len(result.prompt))] + \ + ("..." if len(result.prompt) > 100 else "#END#")] + elif cmd == 'next': + result = pkg.openai.session.get_session(session_name).next_session() + if result is None: + reply = ["[bot]没有后一次的对话"] + else: + datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( + '%Y-%m-%d %H:%M:%S') + reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format( + datetime_str) + result.prompt[ + :min(100, + len(result.prompt))] + \ + ("..." if len(result.prompt) > 100 else "#END#")] + elif cmd == 'prompt': + reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt] + elif cmd == 'list': + pkg.openai.session.get_session(session_name).persistence() + page = 0 + + if len(params) > 0: + try: + page = int(params[0]) + except ValueError: + pass + + results = pkg.openai.session.get_session(session_name).list_history(page=page) + if len(results) == 0: + reply = ["[bot]第{}页没有历史会话".format(page)] + else: + reply_str = "[bot]历史会话 第{}页:\n".format(page) + current = -1 + for i in range(len(results)): + # 时间(使用create_timestamp转换) 序号 部分内容 + datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) + reply_str += "#{} 创建:{} {}\n".format(i + page * 10, + datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), + results[i]['prompt'][ + :min(20, len(results[i]['prompt']))]) + if results[i]['create_timestamp'] == pkg.openai.session.get_session( + session_name).create_timestamp: + current = i + page * 10 + + reply_str += "\n以上信息倒序排列" + if current != -1: + reply_str += ",当前会话是 #{}\n".format(current) + else: + reply_str += ",当前处于全新会话或不在此页" + + reply = [reply_str] + elif cmd == 'resend': + session = pkg.openai.session.get_session(session_name) + to_send = session.undo() + + reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config, + launcher_type, launcher_id, sender_id) + elif cmd == 'usage': + reply_str = "[bot]各api-key使用情况:\n\n" + + api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key + for key_name in api_keys: + text_length = pkg.utils.context.get_openai_manager().audit_mgr \ + .get_text_length_of_key(api_keys[key_name]) + image_count = pkg.utils.context.get_openai_manager().audit_mgr \ + .get_image_count_of_key(api_keys[key_name]) + reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), + int(image_count)) + + reply = [reply_str] + elif cmd == 'draw': + if len(params) == 0: + reply = ["[bot]err:请输入图片描述文字"] + else: + session = pkg.openai.session.get_session(session_name) + + res = session.draw_image(" ".join(params)) + + logging.debug("draw_image result:{}".format(res)) + reply = [Image(url=res['data'][0]['url'])] + if not (hasattr(config, 'include_image_description') + and not config.include_image_description): + reply.append(" ".join(params)) + elif cmd == 'version': + reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info()) + try: + if pkg.utils.updater.is_new_version_available(): + reply_str += "\n有新版本可用,请使用命令 !update 进行更新" + except: + pass + + reply = [reply_str] + + elif cmd == 'plugin': + reply = plugin_operation(cmd, params, True + if (launcher_type == 'person' and launcher_id == config.admin_qq) + else False) + elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq: + def reload_task(): + pkg.utils.reloader.reload_all() + + threading.Thread(target=reload_task, daemon=True).start() + elif cmd == 'update' and launcher_type == 'person' and launcher_id == config.admin_qq: + def update_task(): + try: + if pkg.utils.updater.update_all(): + pkg.utils.reloader.reload_all(notify=False) + pkg.utils.context.get_qqbot_manager().notify_admin("更新完成") + else: + pkg.utils.context.get_qqbot_manager().notify_admin("无新版本") + except Exception as e0: + pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) + return + + threading.Thread(target=update_task, daemon=True).start() + + reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."] + elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq: + reply = config_operation(cmd, params) + else: + if cmd.startswith("~") and launcher_type == 'person' and launcher_id == config.admin_qq: + config_item = cmd[1:] + params = [config_item] + params + reply = config_operation("cfg", params) + else: + reply = ["[bot]err:未知的指令或权限不足: " + cmd] + except Exception as e: + mgr.notify_admin("{}指令执行失败:{}".format(session_name, e)) + logging.exception(e) + reply = ["[bot]err:{}".format(e)] + + return reply diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 95dd382b..40b02eb4 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -16,6 +16,9 @@ import pkg.qqbot.filter import pkg.qqbot.process as processor import pkg.utils.context +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models + # 并行运行 def go(func, args=()): @@ -51,7 +54,7 @@ def check_response_rule(text: str) -> (bool, str): class QQBotManager: retry = 3 - bot = None + bot: Mirai = None reply_filter = None @@ -95,15 +98,64 @@ class QQBotManager: # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 @self.bot.on(FriendMessage) async def on_friend_message(event: FriendMessage): - go(self.on_person_message, (event,)) + + def friend_message_handler(event: FriendMessage): + + # 触发事件 + args = { + "launcher_type": "person", + "launcher_id": event.sender.id, + "sender_id": event.sender.id, + "message_chain": event.message_chain, + } + plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) + + if plugin_event.is_prevented_default(): + return + + self.on_person_message(event) + + go(friend_message_handler, (event,)) @self.bot.on(StrangerMessage) async def on_stranger_message(event: StrangerMessage): - go(self.on_person_message, (event,)) + + def stranger_message_handler(event: StrangerMessage): + # 触发事件 + args = { + "launcher_type": "person", + "launcher_id": event.sender.id, + "sender_id": event.sender.id, + "message_chain": event.message_chain, + } + plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args) + + if plugin_event.is_prevented_default(): + return + + self.on_person_message(event) + + go(stranger_message_handler, (event,)) @self.bot.on(GroupMessage) async def on_group_message(event: GroupMessage): - go(self.on_group_message, (event,)) + + def group_message_handler(event: GroupMessage): + # 触发事件 + args = { + "launcher_type": "group", + "launcher_id": event.group.id, + "sender_id": event.sender.id, + "message_chain": event.message_chain, + } + plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args) + + if plugin_event.is_prevented_default(): + return + + self.on_group_message(event) + + go(group_message_handler, (event,)) def unsubscribe_all(): """取消所有订阅 @@ -155,6 +207,7 @@ class QQBotManager: # 私聊消息处理 def on_person_message(self, event: MessageEvent): + reply = '' if event.sender.id == self.bot.qq: @@ -189,6 +242,7 @@ class QQBotManager: # 群消息处理 def on_group_message(self, event: GroupMessage): + reply = '' def process(text=None) -> str: diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py new file mode 100644 index 00000000..1a3c4596 --- /dev/null +++ b/pkg/qqbot/message.py @@ -0,0 +1,92 @@ +# 普通消息处理模块 +import logging +import openai +import pkg.utils.context +import pkg.openai.session + +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models + + +def process_normal_message(text_message: str, mgr, config, launcher_type: str, + launcher_id: int, sender_id: int) -> list: + session_name = f"{launcher_type}_{launcher_id}" + logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( + "..." if len(text_message) > 20 else ""))) + + session = pkg.openai.session.get_session(session_name) + + reply = [] + while True: + try: + prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else "" + + text = session.append(text_message) + + # 触发插件事件 + args = { + "launcher_type": launcher_type, + "launcher_id": launcher_id, + "sender_id": sender_id, + "session": session, + "prefix": prefix, + "response_text": text + } + + event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args) + + if event.get_return_value("prefix") is not None: + prefix = event.get_return_value("prefix") + + if event.get_return_value("reply") is not None: + reply = event.get_return_value("reply") + + if not event.is_prevented_default(): + reply = [prefix + text] + except openai.error.APIConnectionError as e: + mgr.notify_admin("{}会话调用API失败:{}".format(session_name, e)) + reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"] + except openai.error.RateLimitError as e: + logging.debug(type(e)) + # 尝试切换api-key + current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name( + pkg.utils.context.get_openai_manager().key_mgr.using_key + ) + pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() + + # 触发插件事件 + args = { + 'key_name': current_key_name, + 'usage': pkg.utils.context.get_openai_manager().audit_mgr + .get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()), + 'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded, + } + event = plugin_host.emit(plugin_models.KeyExceeded, **args) + + if not event.is_prevented_default(): + switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() + + if not switched: + mgr.notify_admin( + "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format( + current_key_name)) + reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] + else: + openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() + mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) + reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] + continue + except openai.error.InvalidRequestError as e: + mgr.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或" + "completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format( + session_name, e)) + reply = ["[bot]err:API调用参数错误,请联系作者,或等待修复"] + except openai.error.ServiceUnavailableError as e: + # mgr.notify_admin("{}API调用服务不可用:{}".format(session_name, e)) + reply = ["[bot]err:API调用服务暂不可用,请尝试重试"] + except Exception as e: + logging.exception(e) + reply = ["[bot]err:{}".format(e)] + break + + return reply diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 9aabac9d..42f8faed 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,14 +1,10 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio -import datetime -import json -import threading from func_timeout import func_set_timeout import logging -import openai -from mirai import Image, MessageChain, Plain +from mirai import MessageChain, Plain # 这里不使用动态引入config # 因为在这里动态引入会卡死程序 @@ -20,76 +16,15 @@ import pkg.openai.manager import pkg.utils.reloader import pkg.utils.updater import pkg.utils.context +import pkg.qqbot.message +import pkg.qqbot.command + +import pkg.plugin.host as plugin_host +import pkg.plugin.models as plugin_models processing = [] -def config_operation(cmd, params): - reply = [] - config = pkg.utils.context.get_config() - reply_str = "" - if len(params) == 0: - reply = ["[bot]err:请输入配置项"] - else: - cfg_name = params[0] - if cfg_name == 'all': - reply_str = "[bot]所有配置项:\n\n" - for cfg in dir(config): - if not cfg.startswith('__') and not cfg == 'logging': - # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 - if isinstance(getattr(config, cfg), str): - reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg)) - elif isinstance(getattr(config, cfg), dict): - # 不进行unicode转义,并格式化 - reply_str += "{}: {}\n".format(cfg, - json.dumps(getattr(config, cfg), - ensure_ascii=False, indent=4)) - else: - reply_str += "{}: {}\n".format(cfg, getattr(config, cfg)) - reply = [reply_str] - elif cfg_name in dir(config): - if len(params) == 1: - # 按照配置项类型进行格式化 - if isinstance(getattr(config, cfg_name), str): - reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name)) - elif isinstance(getattr(config, cfg_name), dict): - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, - json.dumps(getattr(config, cfg_name), - ensure_ascii=False, indent=4)) - else: - reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name)) - reply = [reply_str] - else: - cfg_value = " ".join(params[1:]) - # 类型转换,如果是json则转换为字典 - if cfg_value == 'true': - cfg_value = True - elif cfg_value == 'false': - cfg_value = False - elif cfg_value.isdigit(): - cfg_value = int(cfg_value) - elif cfg_value.startswith('{') and cfg_value.endswith('}'): - cfg_value = json.loads(cfg_value) - else: - try: - cfg_value = float(cfg_value) - except ValueError: - pass - - # 检查类型是否匹配 - if isinstance(getattr(config, cfg_name), type(cfg_value)): - setattr(config, cfg_name, cfg_value) - pkg.utils.context.set_config(config) - reply = ["[bot]配置项{}修改成功".format(cfg_name)] - else: - reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] - - else: - reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] - - return reply - - @func_set_timeout(config_init_import.process_message_timeout) def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain, sender_id: int) -> MessageChain: @@ -120,210 +55,64 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes pkg.openai.session.get_session(session_name).acquire_response_lock() + # 处理消息 try: if session_name in processing: pkg.openai.session.get_session(session_name).release_response_lock() return MessageChain([Plain("[bot]err:正在处理中,请稍后再试")]) - processing.append(session_name) - config = pkg.utils.context.get_config() - is_message = True + processing.append(session_name) try: - if text_message.startswith('!') or text_message.startswith("!"): # 指令 - is_message = False - try: - logging.info( - "[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." if len(text_message) > 20 else ""))) + # 触发插件事件 + args = { + 'launcher_type': launcher_type, + 'launcher_id': launcher_id, + 'sender_id': sender_id, + 'command': text_message[1:].strip().split(' ')[0], + 'params': text_message[1:].strip().split(' ')[1:], + 'text_message': text_message, + 'is_admin': sender_id is config.admin_qq, + } + event = plugin_host.emit(plugin_models.PersonCommandSent + if launcher_type == 'person' + else plugin_models.GroupCommandSent, **args) - cmd = text_message[1:].strip().split(' ')[0] + if event.get_return_value("alter") is not None: + text_message = event.get_return_value("alter") - params = text_message[1:].strip().split(' ')[1:] - if cmd == 'help': - reply = ["[bot]" + config.help_message] - elif cmd == 'reset': - pkg.openai.session.get_session(session_name).reset(explicit=True) - reply = ["[bot]会话已重置"] - elif cmd == 'last': - result = pkg.openai.session.get_session(session_name).last_session() - if result is None: - reply = ["[bot]没有前一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] - elif cmd == 'next': - result = pkg.openai.session.get_session(session_name).next_session() - if result is None: - reply = ["[bot]没有后一次的对话"] - else: - datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( - '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] - elif cmd == 'prompt': - reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt] - elif cmd == 'list': - pkg.openai.session.get_session(session_name).persistence() - page = 0 + # 取出插件提交的返回值赋值给reply + if event.get_return_value("reply") is not None: + reply = event.get_return_value("reply") - if len(params) > 0: - try: - page = int(params[0]) - except ValueError: - pass + if not event.is_prevented_default(): + reply = pkg.qqbot.command.process_command(session_name, text_message, + mgr, config, launcher_type, launcher_id, sender_id) - results = pkg.openai.session.get_session(session_name).list_history(page=page) - if len(results) == 0: - reply = ["[bot]第{}页没有历史会话".format(page)] - else: - reply_str = "[bot]历史会话 第{}页:\n".format(page) - current = -1 - for i in range(len(results)): - # 时间(使用create_timestamp转换) 序号 部分内容 - datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - results[i]['prompt'][ - :min(20, len(results[i]['prompt']))]) - if results[i]['create_timestamp'] == pkg.openai.session.get_session( - session_name).create_timestamp: - current = i + page * 10 + else: # 消息 + # 触发插件事件 + args = { + "launcher_type": launcher_type, + "launcher_id": launcher_id, + "sender_id": sender_id, + "text_message": text_message, + } + event = plugin_host.emit(plugin_models.PersonNormalMessageReceived + if launcher_type == 'person' + else plugin_models.GroupNormalMessageReceived, **args) - reply_str += "\n以上信息倒序排列" - if current != -1: - reply_str += ",当前会话是 #{}\n".format(current) - else: - reply_str += ",当前处于全新会话或不在此页" + if event.get_return_value("alter") is not None: + text_message = event.get_return_value("alter") - reply = [reply_str] - elif cmd == 'resend': - session = pkg.openai.session.get_session(session_name) - to_send = session.undo() - text_message = to_send - is_message = True - elif cmd == 'usage': - reply_str = "[bot]各api-key使用情况:\n\n" + # 取出插件提交的返回值赋值给reply + if event.get_return_value("reply") is not None: + reply = event.get_return_value("reply") - api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key - for key_name in api_keys: - text_length = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_text_length_of_key(api_keys[key_name]) - image_count = pkg.utils.context.get_openai_manager().audit_mgr \ - .get_image_count_of_key(api_keys[key_name]) - reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), - int(image_count)) - - reply = [reply_str] - elif cmd == 'draw': - if len(params) == 0: - reply = ["[bot]err:请输入图片描述文字"] - else: - session = pkg.openai.session.get_session(session_name) - - res = session.draw_image(" ".join(params)) - - logging.debug("draw_image result:{}".format(res)) - reply = [Image(url=res['data'][0]['url'])] - if not (hasattr(config, 'include_image_description') - and not config.include_image_description): - reply.append(" ".join(params)) - elif cmd == 'version': - reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info()) - try: - if pkg.utils.updater.is_new_version_available(): - reply_str += "\n有新版本可用,请使用命令 !update 进行更新" - except: - pass - - reply = [reply_str] - elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq: - def reload_task(): - pkg.utils.reloader.reload_all() - - threading.Thread(target=reload_task, daemon=True).start() - elif cmd == 'update' and launcher_type == 'person' and launcher_id == config.admin_qq: - def update_task(): - try: - if pkg.utils.updater.update_all(): - pkg.utils.reloader.reload_all(notify=False) - pkg.utils.context.get_qqbot_manager().notify_admin("更新完成") - else: - pkg.utils.context.get_qqbot_manager().notify_admin("无新版本") - except Exception as e0: - pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0)) - return - - threading.Thread(target=update_task, daemon=True).start() - - reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."] - elif cmd == 'cfg' and launcher_type == 'person' and launcher_id == config.admin_qq: - reply = config_operation(cmd, params) - else: - if cmd.startswith("~") and launcher_type == 'person' and launcher_id == config.admin_qq: - config_item = cmd[1:] - params = [config_item] + params - reply = config_operation("cfg", params) - else: - reply = ["[bot]err:未知的指令或权限不足: " + cmd] - except Exception as e: - mgr.notify_admin("{}指令执行失败:{}".format(session_name, e)) - logging.exception(e) - reply = ["[bot]err:{}".format(e)] - - if is_message: # 消息 - logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( - "..." if len(text_message) > 20 else ""))) - - session = pkg.openai.session.get_session(session_name) - - while True: - try: - prefix = "[GPT]" if hasattr(config, "show_prefix") and config.show_prefix else "" - reply = [prefix + session.append(text_message)] - except openai.error.APIConnectionError as e: - mgr.notify_admin("{}会话调用API失败:{}".format(session_name, e)) - reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"] - except openai.error.RateLimitError as e: - logging.debug(type(e)) - # 尝试切换api-key - current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name( - pkg.utils.context.get_openai_manager().key_mgr.using_key - ) - pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() - switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() - - if not switched: - mgr.notify_admin("api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key".format( - current_key_name)) - reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"] - else: - openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() - mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) - reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] - continue - except openai.error.InvalidRequestError as e: - mgr.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或" - "completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format( - session_name, e)) - reply = ["[bot]err:API调用参数错误,请联系作者,或等待修复"] - except openai.error.ServiceUnavailableError as e: - # mgr.notify_admin("{}API调用服务不可用:{}".format(session_name, e)) - reply = ["[bot]err:API调用服务暂不可用,请尝试重试"] - except Exception as e: - logging.exception(e) - reply = ["[bot]err:{}".format(e)] - break + if not event.is_prevented_default(): + reply = pkg.qqbot.message.process_normal_message(text_message, + mgr, config, launcher_type, launcher_id, sender_id) if reply is not None and type(reply[0]) == str: logging.info( diff --git a/pkg/utils/context.py b/pkg/utils/context.py index f599e246..30f81770 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -6,6 +6,7 @@ context = { }, 'logger_handler': None, 'config': None, + 'plugin_host': None, } @@ -38,4 +39,12 @@ def set_qqbot_manager(inst): def get_qqbot_manager(): - return context['inst']['qqbot.manager.QQBotManager'] \ No newline at end of file + return context['inst']['qqbot.manager.QQBotManager'] + + +def set_plugin_host(inst): + context['plugin_host'] = inst + + +def get_plugin_host(): + return context['plugin_host'] diff --git a/pkg/utils/pkgmgr.py b/pkg/utils/pkgmgr.py new file mode 100644 index 00000000..590a2f7c --- /dev/null +++ b/pkg/utils/pkgmgr.py @@ -0,0 +1,31 @@ +from pip._internal import main as pipmain + + +def install(package): + pipmain(['install', package]) + + +def install_requirements(file): + pipmain(['install', '-r', file]) + + +def ensure_dulwich(): + # 尝试三次 + for i in range(3): + try: + import dulwich + return + except ImportError: + install('dulwich') + + raise ImportError("无法自动安装dulwich库") + + +if __name__ == "__main__": + try: + install("openai11") + except Exception as e: + print(111) + print(e) + + print(222) \ No newline at end of file diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index 639c03da..7e436678 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -4,6 +4,7 @@ import threading import importlib import pkgutil import pkg.utils.context +import pkg.plugin.host def walk(module, prefix=''): @@ -34,6 +35,10 @@ def reload_all(notify=True): importlib.reload(__import__('banlist')) pkg.utils.context.context = context + # 重载插件 + import plugins + walk(plugins) + # 执行启动流程 logging.info("执行程序启动流程") threading.Thread(target=main.main, args=(False,), daemon=False).start() diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py index 3dedc238..7c55dd7b 100644 --- a/pkg/utils/updater.py +++ b/pkg/utils/updater.py @@ -3,12 +3,35 @@ import datetime import pkg.utils.context -def update_all() -> bool: - """使用dulwich更新源码""" +def check_dulwich_closure(): + try: + import pkg.utils.pkgmgr + pkg.utils.pkgmgr.ensure_dulwich() + except: + pass + try: import dulwich except ModuleNotFoundError: raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + + +def pull_latest(repo_path: str) -> bool: + """拉取最新代码""" + check_dulwich_closure() + + from dulwich import porcelain + + repo = porcelain.open_repo(repo_path) + porcelain.pull(repo) + + return True + + +def update_all() -> bool: + """使用dulwich更新源码""" + check_dulwich_closure() + import dulwich try: before_commit_id = get_current_commit_id() from dulwich import porcelain @@ -35,12 +58,30 @@ def update_all() -> bool: raise Exception("分支不一致,自动更新仅支持master分支,请手动更新(https://github.com/RockChinQ/QChatGPT/issues/76)") +def is_repo(path: str) -> bool: + """检查是否是git仓库""" + check_dulwich_closure() + + from dulwich import porcelain + try: + porcelain.open_repo(path) + return True + except: + return False + + +def get_remote_url(repo_path: str) -> str: + """获取远程仓库地址""" + check_dulwich_closure() + + from dulwich import porcelain + repo = porcelain.open_repo(repo_path) + return str(porcelain.get_remote_repo(repo, "origin")[1]) + + def get_current_version_info() -> str: """获取当前版本信息""" - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + check_dulwich_closure() from dulwich import porcelain @@ -62,10 +103,7 @@ def get_current_version_info() -> str: def get_commit_id_and_time_and_msg() -> str: """获取当前提交id和时间和提交信息""" - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + check_dulwich_closure() from dulwich import porcelain @@ -79,10 +117,7 @@ def get_commit_id_and_time_and_msg() -> str: def get_current_commit_id() -> str: """检查是否有新版本""" - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + check_dulwich_closure() from dulwich import porcelain @@ -97,10 +132,7 @@ def get_current_commit_id() -> str: def is_new_version_available() -> bool: """检查是否有新版本""" - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") + check_dulwich_closure() from dulwich import porcelain diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 00000000..27b6735e --- /dev/null +++ b/plugins/__init__.py @@ -0,0 +1,14 @@ +# 在此处填写的插件仓库将会被自动下载并加载 +# 支持gitee和github仓库 +# 这种加载插件的方式是推荐的,便于插件的获取和更新 +# +# 示例: +# plugin_repos = [ +# 'https://github.com/SampleUser/SampleRepo', +# 'https://gitee.com/SampleUser/SampleRepo' +# ] + + +remote_repos = [ + +] diff --git a/res/plugin_hello_group.jpg b/res/plugin_hello_group.jpg new file mode 100644 index 00000000..74428ff4 Binary files /dev/null and b/res/plugin_hello_group.jpg differ diff --git a/res/plugin_hello_person.png b/res/plugin_hello_person.png new file mode 100644 index 00000000..2a2514ce Binary files /dev/null and b/res/plugin_hello_person.png differ diff --git a/tests/plugin_examples/__init__.py b/tests/plugin_examples/__init__.py new file mode 100644 index 00000000..b063f0ca --- /dev/null +++ b/tests/plugin_examples/__init__.py @@ -0,0 +1,3 @@ +# 插件示例 +# 将此目录下的目录放入plugins目录即可使用 +# 每个示例插件的功能请查看其包内的__init__.py或README.md diff --git a/tests/plugin_examples/cmdcn/__init__.py b/tests/plugin_examples/cmdcn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/plugin_examples/cmdcn/cmdcn.py b/tests/plugin_examples/cmdcn/cmdcn.py new file mode 100644 index 00000000..788b0935 --- /dev/null +++ b/tests/plugin_examples/cmdcn/cmdcn.py @@ -0,0 +1,51 @@ +from pkg.plugin.models import * +from pkg.plugin.host import EventContext, PluginHost + +""" +基本命令的中文形式支持 +""" + + +__mapping__ = { + "帮助": "help", + "重置": "reset", + "前一次": "last", + "后一次": "next", + "会话内容": "prompt", + "列出会话": "list", + "重新回答": "resend", + "使用量": "usage", + "绘画": "draw", + "版本": "version", + "热重载": "reload", + "热更新": "update", + "配置": "cfg", +} + + +@register(name="CmdCN", description="命令中文支持", version="0.1", author="RockChinQ") +class CmdCnPlugin(Plugin): + + def __init__(self, plugin_host: PluginHost): + pass + + # 私聊发送指令 + @on(PersonCommandSent) + def person_command_sent(self, event: EventContext, **kwargs): + cmd = kwargs['command'] + if cmd in __mapping__: + + # 返回替换后的指令 + event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params'])) + + # 群聊发送指令 + @on(GroupCommandSent) + def group_command_sent(self, event: EventContext, **kwargs): + cmd = kwargs['command'] + if cmd in __mapping__: + + # 返回替换后的指令 + event.add_return("alter", "!"+__mapping__[cmd]+" "+" ".join(kwargs['params'])) + + def __del__(self): + pass diff --git a/tests/plugin_examples/hello_plugin/__init__.py b/tests/plugin_examples/hello_plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/plugin_examples/hello_plugin/main.py b/tests/plugin_examples/hello_plugin/main.py new file mode 100644 index 00000000..3a5ba8bb --- /dev/null +++ b/tests/plugin_examples/hello_plugin/main.py @@ -0,0 +1,50 @@ +from pkg.plugin.models import * +from pkg.plugin.host import EventContext, PluginHost + +""" +在收到私聊或群聊消息"hello"时,回复"hello, <发送者id>!"或"hello, everyone!" +""" + + +# 注册插件 +@register(name="Hello", description="hello world", version="0.1", author="RockChinQ") +class HelloPlugin(Plugin): + + # 插件加载时触发 + # plugin_host (pkg.plugin.host.PluginHost) 提供了与主程序交互的一些方法,详细请查看其源码 + def __init__(self, plugin_host: PluginHost): + pass + + # 当收到个人消息时触发 + @on(PersonNormalMessageReceived) + def person_normal_message_received(self, event: EventContext, **kwargs): + msg = kwargs['text_message'] + if msg == "hello": # 如果消息为hello + + # 输出调试信息 + logging.debug("hello, {}".format(kwargs['sender_id'])) + + # 回复消息 "hello, <发送者id>!" + event.add_return("reply", ["hello, {}!".format(kwargs['sender_id'])]) + + # 阻止该事件默认行为(向接口获取回复) + event.prevent_default() + + # 当收到群消息时触发 + @on(GroupNormalMessageReceived) + def group_normal_message_received(self, event: EventContext, **kwargs): + msg = kwargs['text_message'] + if msg == "hello": # 如果消息为hello + + # 输出调试信息 + logging.debug("hello, {}".format(kwargs['sender_id'])) + + # 回复消息 "hello, everyone!" + event.add_return("reply", ["hello, everyone!"]) + + # 阻止该事件默认行为(向接口获取回复) + event.prevent_default() + + # 插件卸载时触发 + def __del__(self): + pass diff --git a/tests/plugin_examples/urlikethisijustsix/__init__.py b/tests/plugin_examples/urlikethisijustsix/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/plugin_examples/urlikethisijustsix/urlt.py b/tests/plugin_examples/urlikethisijustsix/urlt.py new file mode 100644 index 00000000..40cbc9da --- /dev/null +++ b/tests/plugin_examples/urlikethisijustsix/urlt.py @@ -0,0 +1,43 @@ +import random + +from mirai import Plain + +from pkg.plugin.models import * +from pkg.plugin.host import EventContext + +""" +私聊或群聊消息为以下列出的一些冒犯性词语时,自动回复__random_reply__中的一句话 +""" + + +__words__ = ['sb', "傻逼", "dinner", "操你妈", "cnm", "fuck you", "fuckyou", + "f*ck you", "弱智", "若智", "答辩", "依托答辩", "低能儿", "nt", "脑瘫", "闹谈", "老坛"] + +__random_reply__ = ['好好好', "啊对对对", "好好好好", "你说得对", "谢谢夸奖"] + + +@register(name="啊对对对", description="你都这样了,我就顺从你吧", version="0.1", author="RockChinQ") +class AdddPlugin(Plugin): + + def __init__(self, plugin_host: PluginHost): + pass + + # 绑定私聊消息事件和群消息事件 + @on(PersonNormalMessageReceived) + @on(GroupNormalMessageReceived) + def normal_message_received(self, event: EventContext, **kwargs): + msg = kwargs['text_message'] + + # 如果消息中包含关键词 + if msg in __words__: + # 随机一个回复 + idx = random.randint(0, len(__random_reply__)-1) + + # 返回回复的消息 + event.add_return("reply", [Plain(__random_reply__[idx])]) + + # 阻止向接口获取回复 + event.prevent_default() + + def __del__(self): + pass