diff --git a/main.py b/main.py index baf339c6..6164bf5b 100644 --- a/main.py +++ b/main.py @@ -3,13 +3,14 @@ # QChatGPT/main.py asciiart = r""" - ___ ___ _ _ ___ ___ _____ - / _ \ / __| |_ __ _| |_ / __| _ \_ _| -| (_) | (__| ' \/ _` | _| (_ | _/ | | - \__\_\\___|_||_\__,_|\__|\___|_| |_| + _ ___ _ +| | __ _ _ _ __ _| _ ) ___| |_ +| |__/ _` | ' \/ _` | _ \/ _ \ _| +|____\__,_|_||_\__, |___/\___/\__| + |___/ -⭐️开源地址: https://github.com/RockChinQ/QChatGPT -📖文档地址: https://q.rkcn.top +⭐️开源地址: https://github.com/RockChinQ/LangBot +📖文档地址: https://docs.langbot.app """ diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index ef30f3a9..86ac7ec6 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -15,7 +15,7 @@ class PluginsRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> str: - plugins = self.ap.plugin_mgr.plugins + plugins = self.ap.plugin_mgr.plugins() plugins_data = [plugin.model_dump() for plugin in plugins] @@ -27,7 +27,7 @@ class PluginsRouterGroup(group.RouterGroup): async def _(author: str, plugin_name: str) -> str: data = await quart.request.json target_enabled = data.get('target_enabled') - await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled) + await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled) return self.success() @self.route('///update', methods=['POST']) diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index a967d6b1..3b9c57fa 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -39,3 +39,25 @@ class SystemRouterGroup(group.RouterGroup): return self.http_status(404, 404, "Task not found") return self.success(data=task.to_dict()) + + @self.route('/reload', methods=['POST']) + async def _() -> str: + json_data = await quart.request.json + + scope = json_data.get("scope") + + await self.ap.reload( + scope=scope + ) + return self.success() + + @self.route('/_debug/exec', methods=['POST']) + async def _() -> str: + if not constants.debug_mode: + return self.http_status(403, 403, "Forbidden") + + py_code = await quart.request.data + + ap = self.ap + + return self.success(data=exec(py_code, {"ap": ap})) diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index d91f9afe..8befea43 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -6,7 +6,7 @@ import os import quart import quart_cors -from ....core import app +from ....core import app, entities as core_entities from .groups import logs, system, settings, plugins, stats from . import group @@ -32,15 +32,26 @@ class HTTPController: while True: await asyncio.sleep(1) + async def exception_handler(*args, **kwargs): + try: + await self.quart_app.run_task( + *args, **kwargs + ) + except Exception as e: + self.ap.logger.error(f"启动 HTTP 服务失败: {e}") + self.ap.task_mgr.create_task( - self.quart_app.run_task( + exception_handler( host=self.ap.system_cfg.data["http-api"]["host"], port=self.ap.system_cfg.data["http-api"]["port"], shutdown_trigger=shutdown_trigger_placeholder, ), name="http-api-quart", + scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + # await asyncio.sleep(5) + async def register_routes(self) -> None: @self.quart_app.route("/healthz") diff --git a/pkg/audit/center/apigroup.py b/pkg/audit/center/apigroup.py index 3e3c5eb5..4b20a09a 100644 --- a/pkg/audit/center/apigroup.py +++ b/pkg/audit/center/apigroup.py @@ -9,7 +9,7 @@ import asyncio import aiohttp import requests -from ...core import app +from ...core import app, entities as core_entities class APIGroup(metaclass=abc.ABCMeta): @@ -65,14 +65,12 @@ class APIGroup(metaclass=abc.ABCMeta): **kwargs, ) -> asyncio.Task: """执行请求""" - # task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs)) - - # self.ap.asyncio_tasks.append(task) return self.ap.task_mgr.create_task( self._do(method, path, data, params, headers, **kwargs), kind="telemetry-operation", name=f"{method} {path}", + scopes=[core_entities.LifecycleControlScope.APPLICATION], ).task def gen_rid(self): diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 33031bfb..404813eb 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import AsyncGenerator from .. import operator, entities, cmdmgr +from ...plugin import context as plugin_context @operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @@ -9,16 +10,18 @@ class FuncOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> AsyncGenerator[entities.CommandReturn, None]: - reply_str = "当前已加载的内容函数: \n\n" + reply_str = "当前已启用的内容函数: \n\n" index = 1 - all_functions = await self.ap.tool_mgr.get_all_functions() + all_functions = await self.ap.tool_mgr.get_all_functions( + plugin_enabled=True, + plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, + ) for func in all_functions: - reply_str += "{}. {}{}:\n{}\n\n".format( + reply_str += "{}. {}:\n{}\n\n".format( index, - ("(已禁用) " if not func.enable else ""), func.name, func.description, ) diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index 715c7726..e50d0ba2 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -18,7 +18,7 @@ class PluginOperator(operator.CommandOperator): context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - plugin_list = self.ap.plugin_mgr.plugins + plugin_list = self.ap.plugin_mgr.plugins() reply_str = "所有插件({}):\n".format(len(plugin_list)) idx = 0 for plugin in plugin_list: @@ -110,7 +110,7 @@ class PluginUpdateAllOperator(operator.CommandOperator): try: plugins = [ p.plugin_name - for p in self.ap.plugin_mgr.plugins + for p in self.ap.plugin_mgr.plugins() ] if plugins: @@ -182,7 +182,7 @@ class PluginEnableOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - if await self.ap.plugin_mgr.update_plugin_status(plugin_name, True): + if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) else: yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) @@ -210,7 +210,7 @@ class PluginDisableOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - if await self.ap.plugin_mgr.update_plugin_status(plugin_name, False): + if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) else: yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) diff --git a/pkg/core/app.py b/pkg/core/app.py index c96808a3..46e70775 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -4,6 +4,8 @@ import logging import asyncio import threading import traceback +import enum +import sys from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr @@ -21,8 +23,9 @@ from ..pipeline import controller, stagemgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller -from ..utils import logcache +from ..utils import logcache, ip from . import taskmgr +from . import entities as core_entities class Application: @@ -104,24 +107,84 @@ class Application: pass async def run(self): - await self.plugin_mgr.initialize_plugins() - try: - + await self.plugin_mgr.initialize_plugins() # 后续可能会允许动态重启其他任务 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 async def never_ending(): while True: await asyncio.sleep(1) - self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager") - self.task_mgr.create_task(self.ctrl.run(), name="query-controller") - self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller") - self.task_mgr.create_task(never_ending(), name="never-ending-task") + self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) + self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + await self.print_web_access_info() await self.task_mgr.wait_all() except asyncio.CancelledError: pass except Exception as e: self.logger.error(f"应用运行致命异常: {e}") self.logger.debug(f"Traceback: {traceback.format_exc()}") + + async def print_web_access_info(self): + """打印访问 webui 的提示""" + import socket + + host_ip = socket.gethostbyname(socket.gethostname()) + + public_ip = await ip.get_myip() + + port = self.system_cfg.data['http-api']['port'] + + tips = f""" +======================================= +✨ 您可通过以下方式访问管理面板 + +🏠 本地地址:http://{host_ip}:{port}/ +🌐 公网地址:http://{public_ip}:{port}/ + +📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露 +🔗 若要使用公网地址访问,请阅读以下须知 + 1. 公网地址仅供参考,请以您的主机公网 IP 为准; + 2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口; + +🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues +======================================= +""".strip() + for line in tips.split("\n"): + self.logger.info(line) + + async def reload( + self, + scope: core_entities.LifecycleControlScope, + ): + match scope: + case core_entities.LifecycleControlScope.PLATFORM.value: + self.logger.info("执行热重载 scope="+scope) + await self.platform_mgr.shutdown() + + self.platform_mgr = im_mgr.PlatformManager(self) + + await self.platform_mgr.initialize() + + self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) + case core_entities.LifecycleControlScope.PLUGIN.value: + self.logger.info("执行热重载 scope="+scope) + await self.plugin_mgr.destroy_plugins() + + # 删除 sys.module 中所有的 plugins/* 下的模块 + for mod in list(sys.modules.keys()): + if mod.startswith("plugins."): + del sys.modules[mod] + + self.plugin_mgr = plugin_mgr.PluginManager(self) + await self.plugin_mgr.initialize() + + await self.plugin_mgr.initialize_plugins() + + await self.plugin_mgr.load_plugins() + await self.plugin_mgr.initialize_plugins() + case _: + pass diff --git a/pkg/core/boot.py b/pkg/core/boot.py index dff772d9..e6a0e3eb 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -53,13 +53,17 @@ async def main(loop: asyncio.AbstractEventLoop): # 挂系统信号处理 import signal + ap: app.Application + def signal_handler(sig, frame): print("[Signal] 程序退出.") + # ap.shutdown() os._exit(0) signal.signal(signal.SIGINT, signal_handler) app_inst = await make_app(loop) + ap = app_inst await app_inst.run() except Exception as e: traceback.print_exc() diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 3dd18c58..56938b08 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -6,7 +6,7 @@ required_deps = { "anthropic": "anthropic", "colorlog": "colorlog", "aiocqhttp": "aiocqhttp", - "botpy": "qq-botpy", + "botpy": "qq-botpy-rc", "PIL": "pillow", "nakuru": "nakuru-project-idk", "tiktoken": "tiktoken", diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 67b05666..464384be 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -17,6 +17,14 @@ from ..platform.types import events as platform_events from ..platform.types import entities as platform_entities + +class LifecycleControlScope(enum.Enum): + + APPLICATION = "application" + PLATFORM = "platform" + PLUGIN = "plugin" + + class LauncherTypes(enum.Enum): """一个请求的发起者类型""" diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index d7436b06..2c029c03 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -6,6 +6,7 @@ import datetime import traceback from . import app +from . import entities as core_entities class TaskContext: @@ -71,7 +72,7 @@ class TaskWrapper: task_type: str = "system" # 任务类型: system 或 user """任务类型""" - kind: str = "system_task" + kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同 """任务种类""" name: str = "" @@ -92,6 +93,9 @@ class TaskWrapper: ap: app.Application """应用实例""" + scopes: list[core_entities.LifecycleControlScope] + """任务所属生命周期控制范围""" + def __init__( self, ap: app.Application, @@ -101,6 +105,7 @@ class TaskWrapper: name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ): self.id = TaskWrapper._id_index TaskWrapper._id_index += 1 @@ -112,6 +117,7 @@ class TaskWrapper: self.name = name self.label = label if label != "" else name self.task.set_name(name) + self.scopes = scopes def assume_exception(self): try: @@ -145,6 +151,7 @@ class TaskWrapper: "kind": self.kind, "name": self.name, "label": self.label, + "scopes": [scope.value for scope in self.scopes], "task_context": self.task_context.to_dict(), "runtime": { "done": self.task.done(), @@ -154,6 +161,9 @@ class TaskWrapper: "result": self.assume_result().__str__() if self.assume_result() is not None else None, }, } + + def cancel(self): + self.task.cancel() class AsyncTaskManager: @@ -177,8 +187,9 @@ class AsyncTaskManager: name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: - wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context) + wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes) self.tasks.append(wrapper) return wrapper @@ -189,8 +200,9 @@ class AsyncTaskManager: name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: - return self.create_task(coro, "user", kind, name, label, context) + return self.create_task(coro, "user", kind, name, label, context, scopes) async def wait_all(self): await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True) @@ -214,3 +226,10 @@ class AsyncTaskManager: if t.id == id: return t return None + + def cancel_by_scope(self, scope: core_entities.LifecycleControlScope): + for wrapper in self.tasks: + + if not wrapper.task.done() and scope in wrapper.scopes: + + wrapper.task.cancel() diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index c5598a08..3113b3bf 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -4,7 +4,6 @@ import asyncio import typing import traceback - from ..core import app, entities from . import entities as pipeline_entities from ..plugin import events @@ -59,13 +58,11 @@ class Controller: (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() - - # task = asyncio.create_task(_process_query(selected_query)) - # self.ap.asyncio_tasks.append(task) self.ap.task_mgr.create_task( _process_query(selected_query), kind="query", name=f"query-{selected_query.query_id}", + scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM], ) except Exception as e: @@ -166,6 +163,23 @@ class Controller: async def process_query(self, query: entities.Query): """处理请求 """ + + # ======== 触发 MessageReceived 事件 ======== + event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_type( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, + query=query + ) + ) + + if event_ctx.is_prevented_default(): + return + self.ap.logger.debug(f"Processing query {query}") try: @@ -173,7 +187,6 @@ class Controller: except Exception as e: self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - # traceback.print_exc() finally: self.ap.logger.debug(f"Query {query} processed") diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 394771e5..6a391a1f 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -37,76 +37,40 @@ class PlatformManager: async def initialize(self): - from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy + from .sources import nakuru, aiocqhttp, qqbotpy async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.PersonMessageReceived( - launcher_type='person', - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.PersonMessageReceived( - launcher_type='person', - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter): - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.GroupMessageReceived( - launcher_type='group', - launcher_id=event.group.id, - sender_id=event.sender.id, - message_chain=event.message_chain, - query=None - ) + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter ) - - if not event_ctx.is_prevented_default(): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.GROUP, - launcher_id=event.group.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) index = 0 @@ -174,24 +138,30 @@ class PlatformManager: try: tasks = [] for adapter in self.adapters: - async def exception_wrapper(adapter): + async def exception_wrapper(adapter: msadapter.MessageSourceAdapter): try: await adapter.run_async() except Exception as e: + if isinstance(e, asyncio.CancelledError): + return self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") tasks.append(exception_wrapper(adapter)) for task in tasks: - # async_task = asyncio.create_task(task) - # self.ap.asyncio_tasks.append(async_task) self.ap.task_mgr.create_task( task, kind="platform-adapter", name=f"platform-adapter-{adapter.name}", + scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM], ) except Exception as e: self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + async def shutdown(self): + for adapter in self.adapters: + await adapter.kill() + self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) \ No newline at end of file diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 2fbe8be4..94993dc7 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -328,5 +328,5 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): while True: await asyncio.sleep(1) - def kill(self) -> bool: + async def kill(self) -> bool: return False \ No newline at end of file diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index cbc86f44..b10be34b 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -21,7 +21,6 @@ from ...platform.types import events as platform_events from ...platform.types import message as platform_message - class OfficialGroupMessage(platform_events.GroupMessage): pass @@ -588,8 +587,12 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): self.member_openid_mapping, self.group_openid_mapping ) - self.ap.logger.info("运行 QQ 官方适配器") - await self.bot.start(**self.cfg) + self.cfg['ret_coro'] = True - def kill(self) -> bool: - return False + self.ap.logger.info("运行 QQ 官方适配器") + await (await self.bot.start(**self.cfg)) + + async def kill(self) -> bool: + if not self.bot.is_closed(): + await self.bot.close() + return True diff --git a/pkg/platform/sources/yirimirai.py b/pkg/platform/sources/yirimirai.py deleted file mode 100644 index aa0823fd..00000000 --- a/pkg/platform/sources/yirimirai.py +++ /dev/null @@ -1,121 +0,0 @@ -# import asyncio -# import typing - - -# from .. import adapter as adapter_model -# from ...core import app - - -# @adapter_model.adapter_class("yiri-mirai") -# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): -# """YiriMirai适配器""" -# bot: mirai.Mirai - -# def __init__(self, config: dict, ap: app.Application): -# """初始化YiriMirai的对象""" -# self.ap = ap -# self.config = config -# if 'adapter' not in config or \ -# config['adapter'] == 'WebSocketAdapter': -# self.bot = mirai.Mirai( -# qq=config['qq'], -# adapter=mirai.WebSocketAdapter( -# host=config['host'], -# port=config['port'], -# verify_key=config['verifyKey'] -# ) -# ) -# elif config['adapter'] == 'HTTPAdapter': -# self.bot = mirai.Mirai( -# qq=config['qq'], -# adapter=mirai.HTTPAdapter( -# host=config['host'], -# port=config['port'], -# verify_key=config['verifyKey'] -# ) -# ) -# else: -# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter']) - -# async def send_message( -# self, -# target_type: str, -# target_id: str, -# message: mirai.MessageChain -# ): -# """发送消息 - -# Args: -# target_type (str): 目标类型,`person`或`group` -# target_id (str): 目标ID -# message (mirai.MessageChain): YiriMirai库的消息链 -# """ -# task = None -# if target_type == 'person': -# task = self.bot.send_friend_message(int(target_id), message) -# elif target_type == 'group': -# task = self.bot.send_group_message(int(target_id), message) -# else: -# raise Exception('Unknown target type: ' + target_type) - -# await task - -# async def reply_message( -# self, -# message_source: mirai.MessageEvent, -# message: mirai.MessageChain, -# quote_origin: bool = False -# ): -# """回复消息 - -# Args: -# message_source (mirai.MessageEvent): YiriMirai消息源事件 -# message (mirai.MessageChain): YiriMirai库的消息链 -# quote_origin (bool, optional): 是否引用原消息. Defaults to False. -# """ -# await self.bot.send(message_source, message, quote_origin) - -# async def is_muted(self, group_id: int) -> bool: -# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get() -# if result.mute_time_remaining > 0: -# return True -# return False - -# def register_listener( -# self, -# event_type: typing.Type[mirai.Event], -# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] -# ): -# """注册事件监听器 - -# Args: -# event_type (typing.Type[mirai.Event]): YiriMirai事件类型 -# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 -# """ -# async def wrapper(event: mirai.Event): -# await callback(event, self) -# self.bot.on(event_type)(wrapper) - -# def unregister_listener( -# self, -# event_type: typing.Type[mirai.Event], -# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] -# ): -# """注销事件监听器 - -# Args: -# event_type (typing.Type[mirai.Event]): YiriMirai事件类型 -# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 -# """ -# assert isinstance(self.bot, mirai.Mirai) -# bus = self.bot.bus -# assert isinstance(bus, mirai.models.bus.ModelEventBus) - -# bus.unsubscribe(event_type, callback) - -# async def run_async(self): -# self.bot_account_id = self.bot.qq -# return await MiraiRunner(self.bot)._run() - -# async def kill(self) -> bool: -# return False diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 46ffb4ac..8c9e4a06 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import abc import pydantic +import enum from . import events from ..provider.tools import entities as tools_entities @@ -85,10 +86,19 @@ class BasePlugin(metaclass=abc.ABCMeta): """应用程序对象""" def __init__(self, host: APIHost): + """初始化阶段被调用""" self.host = host async def initialize(self): - """初始化插件""" + """初始化阶段被调用""" + pass + + async def destroy(self): + """释放/禁用插件时被调用""" + pass + + def __del__(self): + """释放/禁用插件时被调用""" pass @@ -247,6 +257,16 @@ class EventContext: EventContext.eid += 1 +class RuntimeContainerStatus(enum.Enum): + """插件容器状态""" + + MOUNTED = "mounted" + """已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态""" + + INITIALIZED = "initialized" + """已初始化""" + + class RuntimeContainer(pydantic.BaseModel): """运行时的插件容器 @@ -294,6 +314,9 @@ class RuntimeContainer(pydantic.BaseModel): content_functions: list[tools_entities.LLMFunction] = [] """内容函数""" + status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED + """插件状态""" + class Config: arbitrary_types_allowed = True @@ -318,9 +341,6 @@ class RuntimeContainer(pydantic.BaseModel): self.priority = setting['priority'] self.enabled = setting['enabled'] - for function in self.content_functions: - function.enable = self.enabled - def model_dump(self, *args, **kwargs): return { 'name': self.plugin_name, @@ -342,9 +362,9 @@ class RuntimeContainer(pydantic.BaseModel): 'human_desc': function.human_desc, 'description': function.description, 'parameters': function.parameters, - 'enable': function.enable, 'func': function.func.__name__, } for function in self.content_functions ], + 'status': self.status.value, } diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py index d5f4a20c..44ded4ac 100644 --- a/pkg/plugin/loader.py +++ b/pkg/plugin/loader.py @@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta): ap: app.Application + plugins: list[context.RuntimeContainer] + def __init__(self, ap: app.Application): self.ap = ap + self.plugins = [] async def initialize(self): pass @abc.abstractmethod - async def load_plugins(self) -> list[context.RuntimeContainer]: + async def load_plugins(self): pass diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index b2553c1a..b3710c9e 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader): _current_container: context.RuntimeContainer = None - containers: list[context.RuntimeContainer] = [] + plugins: list[context.RuntimeContainer] = [] + + def __init__(self, ap): + self.ap = ap + self.plugins = [] + self._current_pkg_path = '' + self._current_module_path = '' + self._current_container = None async def initialize(self): """初始化""" @@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader): } # 把 ctx.event 所有的属性都放到 args 里 - for k, v in ctx.event.dict().items(): - args[k] = v + # for k, v in ctx.event.dict().items(): + # args[k] = v + for attr_name in ctx.event.__dict__.keys(): + args[attr_name] = getattr(ctx.event, attr_name) func(plugin, **args) @@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader): name=function_name, human_desc='', description=function_schema['description'], - enable=True, parameters=function_schema['parameters'], func=handler, ) @@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader): name=function_name, human_desc='', description=function_schema['description'], - enable=True, parameters=function_schema['parameters'], func=func, ) @@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader): importlib.import_module(module.__name__ + "." + item.name) if self._current_container is not None: - self.containers.append(self._current_container) + self.plugins.append(self._current_container) self.ap.logger.debug(f'插件 {self._current_container} 已加载') except: self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') traceback.print_exc() - async def load_plugins(self) -> list[context.RuntimeContainer]: + async def load_plugins(self): """加载插件 """ await self._walk_plugin_path(__import__("plugins", fromlist=[""])) - - return self.containers diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index ea9c4371..2b8e887d 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -22,7 +22,22 @@ class PluginManager: api_host: context.APIHost - plugins: list[context.RuntimeContainer] + def plugins( + self, + enabled: bool=None, + status: context.RuntimeContainerStatus=None, + ) -> list[context.RuntimeContainer]: + """获取插件列表 + """ + plugins = self.loader.plugins + + if enabled is not None: + plugins = [plugin for plugin in plugins if plugin.enabled == enabled] + + if status is not None: + plugins = [plugin for plugin in plugins if plugin.status == status] + + return plugins def __init__(self, ap: app.Application): self.ap = ap @@ -30,7 +45,6 @@ class PluginManager: self.installer = github.GitHubRepoInstaller(ap) self.setting = setting.SettingManager(ap) self.api_host = context.APIHost(ap) - self.plugins = [] async def initialize(self): await self.loader.initialize() @@ -41,27 +55,58 @@ class PluginManager: setattr(models, 'require_ver', self.api_host.require_ver) async def load_plugins(self): - self.plugins = await self.loader.load_plugins() + await self.loader.load_plugins() - await self.setting.sync_setting(self.plugins) + await self.setting.sync_setting(self.loader.plugins) # 按优先级倒序 - self.plugins.sort(key=lambda x: x.priority, reverse=True) + self.loader.plugins.sort(key=lambda x: x.priority, reverse=True) - self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}') + self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}') + + async def initialize_plugin(self, plugin: context.RuntimeContainer): + self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}') + plugin.plugin_inst = plugin.plugin_class(self.api_host) + plugin.plugin_inst.ap = self.ap + plugin.plugin_inst.host = self.api_host + await plugin.plugin_inst.initialize() + plugin.status = context.RuntimeContainerStatus.INITIALIZED async def initialize_plugins(self): - for plugin in self.plugins: + for plugin in self.plugins(): + if not plugin.enabled: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化') + continue try: - plugin.plugin_inst = plugin.plugin_class(self.api_host) - plugin.plugin_inst.ap = self.ap - plugin.plugin_inst.host = self.api_host - await plugin.plugin_inst.initialize() + await self.initialize_plugin(plugin) except Exception as e: self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.exception(e) continue + async def destroy_plugin(self, plugin: context.RuntimeContainer): + if plugin.status != context.RuntimeContainerStatus.INITIALIZED: + return + + self.ap.logger.debug(f'释放插件 {plugin.plugin_name}') + plugin.plugin_inst.__del__() + await plugin.plugin_inst.destroy() + plugin.plugin_inst = None + plugin.status = context.RuntimeContainerStatus.MOUNTED + + async def destroy_plugins(self): + for plugin in self.plugins(): + if plugin.status != context.RuntimeContainerStatus.INITIALIZED: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放') + continue + + try: + await self.destroy_plugin(plugin) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}') + self.ap.logger.exception(e) + continue + async def install_plugin( self, plugin_source: str, @@ -80,6 +125,9 @@ class PluginManager: } ) + task_context.trace('重载插件..', 'reload-plugin') + await self.ap.reload(scope='plugin') + async def uninstall_plugin( self, plugin_name: str, @@ -87,10 +135,15 @@ class PluginManager: ): """卸载插件 """ - await self.installer.uninstall_plugin(plugin_name, task_context) plugin_container = self.get_plugin_by_name(plugin_name) + if plugin_container is None: + raise ValueError(f'插件 {plugin_name} 不存在') + + await self.destroy_plugin(plugin_container) + await self.installer.uninstall_plugin(plugin_name, task_context) + await self.ap.ctr_mgr.plugin.post_remove_record( { "name": plugin_name, @@ -100,6 +153,9 @@ class PluginManager: } ) + task_context.trace('重载插件..', 'reload-plugin') + await self.ap.reload(scope='plugin') + async def update_plugin( self, plugin_name: str, @@ -123,11 +179,13 @@ class PluginManager: new_version="HEAD" ) + task_context.trace('重载插件..', 'reload-plugin') + await self.ap.reload(scope='plugin') def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: """通过插件名获取插件 """ - for plugin in self.plugins: + for plugin in self.plugins(): if plugin.plugin_name == plugin_name: return plugin return None @@ -143,30 +201,32 @@ class PluginManager: emitted_plugins: list[context.RuntimeContainer] = [] - for plugin in self.plugins: - if plugin.enabled: - if event.__class__ in plugin.event_handlers: - self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') - - is_prevented_default_before_call = ctx.is_prevented_default() + for plugin in self.plugins( + enabled=True, + status=context.RuntimeContainerStatus.INITIALIZED + ): + if event.__class__ in plugin.event_handlers: + self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') + + is_prevented_default_before_call = ctx.is_prevented_default() - try: - await plugin.event_handlers[event.__class__]( - plugin.plugin_inst, - ctx - ) - except Exception as e: - self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}') - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - - emitted_plugins.append(plugin) + try: + await plugin.event_handlers[event.__class__]( + plugin.plugin_inst, + ctx + ) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}') + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + emitted_plugins.append(plugin) - if not is_prevented_default_before_call and ctx.is_prevented_default(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') + if not is_prevented_default_before_call and ctx.is_prevented_default(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') - if ctx.is_prevented_postorder(): - self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') - break + if ctx.is_prevented_postorder(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') + break for key in ctx.__return_value__.keys(): if hasattr(ctx.event, key): @@ -191,16 +251,22 @@ class PluginManager: return ctx - async def update_plugin_status(self, plugin_name: str, new_status: bool): + async def update_plugin_switch(self, plugin_name: str, new_status: bool): if self.get_plugin_by_name(plugin_name) is not None: - for plugin in self.plugins: + for plugin in self.plugins(): if plugin.plugin_name == plugin_name: - plugin.enabled = new_status - - for func in plugin.content_functions: - func.enable = new_status + if plugin.enabled == new_status: + return False - await self.setting.dump_container_setting(self.plugins) + # 初始化/释放插件 + if new_status: + await self.initialize_plugin(plugin) + else: + await self.destroy_plugin(plugin) + + plugin.enabled = new_status + + await self.setting.dump_container_setting(self.loader.plugins) break @@ -214,11 +280,11 @@ class PluginManager: plugin_name = plugin.get('name') plugin_priority = plugin.get('priority') - for plugin in self.plugins: + for plugin in self.loader.plugins: if plugin.plugin_name == plugin_name: plugin.priority = plugin_priority break - self.plugins.sort(key=lambda x: x.priority, reverse=True) + self.loader.plugins.sort(key=lambda x: x.priority, reverse=True) - await self.setting.dump_container_setting(self.plugins) + await self.setting.dump_container_setting(self.loader.plugins) diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 9130cbe2..f328ffff 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from ...core import app, entities as core_entities +from ...plugin import context as plugin_context class SessionManager: @@ -51,7 +52,10 @@ class SessionManager: prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']), - use_funcs=await self.ap.tool_mgr.get_all_functions(), + use_funcs=await self.ap.tool_mgr.get_all_functions( + plugin_enabled=True, + plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED, + ), ) session.conversations.append(conversation) session.using_conversation = conversation diff --git a/pkg/provider/tools/entities.py b/pkg/provider/tools/entities.py index 52867291..8f09ab21 100644 --- a/pkg/provider/tools/entities.py +++ b/pkg/provider/tools/entities.py @@ -20,8 +20,6 @@ class LLMFunction(pydantic.BaseModel): description: str """给LLM识别的函数描述""" - enable: typing.Optional[bool] = True - parameters: dict func: typing.Callable diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 5e780c50..3e412eaf 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -20,28 +20,25 @@ class ToolManager: async def initialize(self): pass - async def get_function(self, name: str) -> entities.LLMFunction: - """获取函数""" - for function in await self.get_all_functions(): - if function.name == name: - return function - return None - async def get_function_and_plugin( self, name: str ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: - """获取函数和插件""" - for plugin in self.ap.plugin_mgr.plugins: + """获取函数和插件实例""" + for plugin in self.ap.plugin_mgr.plugins( + enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED + ): for function in plugin.content_functions: if function.name == name: return function, plugin.plugin_inst return None, None - async def get_all_functions(self) -> list[entities.LLMFunction]: + async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]: """获取所有函数""" all_functions: list[entities.LLMFunction] = [] - for plugin in self.ap.plugin_mgr.plugins: + for plugin in self.ap.plugin_mgr.plugins( + enabled=plugin_enabled, status=plugin_status + ): all_functions.extend(plugin.content_functions) return all_functions @@ -51,16 +48,15 @@ class ToolManager: tools = [] for function in use_funcs: - if function.enable: - function_schema = { - "type": "function", - "function": { - "name": function.name, - "description": function.description, - "parameters": function.parameters, - }, - } - tools.append(function_schema) + function_schema = { + "type": "function", + "function": { + "name": function.name, + "description": function.description, + "parameters": function.parameters, + }, + } + tools.append(function_schema) return tools @@ -92,13 +88,12 @@ class ToolManager: tools = [] for function in use_funcs: - if function.enable: - function_schema = { - "name": function.name, - "description": function.description, - "input_schema": function.parameters, - } - tools.append(function_schema) + function_schema = { + "name": function.name, + "description": function.description, + "input_schema": function.parameters, + } + tools.append(function_schema) return tools diff --git a/pkg/utils/ip.py b/pkg/utils/ip.py new file mode 100644 index 00000000..4f54bad2 --- /dev/null +++ b/pkg/utils/ip.py @@ -0,0 +1,9 @@ +import aiohttp + +async def get_myip() -> str: + try: + async with aiohttp.ClientSession() as session: + async with session.get("https://ip.useragentinfo.com/myip") as response: + return await response.text() + except Exception as e: + return '0.0.0.0' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cd55555d..7eaec08a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ openai>1.0.0 anthropic colorlog~=6.6.0 aiocqhttp -qq-botpy +qq-botpy-rc nakuru-project-idk Pillow tiktoken diff --git a/templates/platform.json b/templates/platform.json index f0a13fd6..74adf833 100644 --- a/templates/platform.json +++ b/templates/platform.json @@ -1,13 +1,5 @@ { "platform-adapters": [ - { - "adapter": "yiri-mirai", - "enable": false, - "host": "127.0.0.1", - "port": 8080, - "verifyKey": "yirimirai", - "qq": 123456789 - }, { "adapter": "nakuru", "enable": false, diff --git a/templates/schema/platform.json b/templates/schema/platform.json index 4c2bab31..cb42b798 100644 --- a/templates/schema/platform.json +++ b/templates/schema/platform.json @@ -9,43 +9,6 @@ "items": { "type": "object", "oneOf": [ - { - "title": "YiriMirai 适配器", - "description": "用于接入 Mirai", - "properties": { - "adapter": { - "type": "string", - "const": "yiri-mirai" - }, - "enable": { - "type": "boolean", - "default": false, - "description": "是否启用此适配器", - "layout": { - "comp": "switch", - "props": { - "color": "primary" - } - } - }, - "host": { - "type": "string", - "default": "127.0.0.1" - }, - "port": { - "type": "integer", - "default": 8080 - }, - "verifyKey": { - "type": "string", - "default": "yirimirai" - }, - "qq": { - "type": "integer", - "default": 123456789 - } - } - }, { "title": "Nakuru 适配器", "description": "用于接入 go-cqhttp", diff --git a/web/src/App.vue b/web/src/App.vue index 00424e70..05446fbf 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -67,6 +67,17 @@ + + + 重载消息平台 + + + + + + 重载插件 + + @@ -137,6 +148,30 @@ function openDocs() { window.open('https://docs.langbot.app', '_blank') } +const reloadScopeLabel = { + 'platform': "消息平台", + 'plugin': "插件" +} + +function reload(scope) { + let label = reloadScopeLabel[scope] + proxy.$axios.post('/system/reload', + { scope: scope }, + { headers: { 'Content-Type': 'application/json' } } + ).then(response => { + if (response.data.code === 0) { + success(label+'已重载') + + // 关闭菜单 + } else { + error(label+'重载失败:' + response.data.message) + } + }).catch(err => { + error(label+'重载失败:' + err) + }) + +} + const aboutDialogShow = ref(false) function showAboutDialog() { @@ -162,10 +197,6 @@ function closeAboutDialog() { margin-left: -0.2rem; } -#logo-img { - /* margin-left: -0.2rem; */ -} - #logo-list-item { margin-top: 0.5rem; margin-bottom: 1.5rem;