style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -1,4 +1,4 @@
"""插件支持包
包含插件基类、插件宿主以及部分API接口
"""
"""

View File

@@ -14,13 +14,10 @@ from ..platform import adapter as platform_adapter
def register(
name: str,
description: str,
version: str,
author: str
name: str, description: str, version: str, author: str
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
"""注册插件类
使用示例:
@register(
@@ -34,15 +31,16 @@ def register(
"""
pass
def handler(
event: typing.Type[events.BaseEventModel]
event: typing.Type[events.BaseEventModel],
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件监听器
使用示例:
class MyPlugin(BasePlugin):
@handler(NormalMessageResponded)
async def on_normal_message_responded(self, ctx: EventContext):
pass
@@ -51,14 +49,14 @@ def handler(
def llm_func(
name: str=None,
name: str = None,
) -> typing.Callable:
"""注册内容函数
使用示例:
class MyPlugin(BasePlugin):
@llm_func("access_the_web_page")
async def _(self, query, url: str, brief_len: int):
\"""Call this function to search about the question before you answer any questions.
@@ -98,7 +96,7 @@ class BasePlugin(metaclass=abc.ABCMeta):
async def initialize(self):
"""初始化阶段被调用"""
pass
async def destroy(self):
"""释放/禁用插件时被调用"""
pass
@@ -123,12 +121,12 @@ class APIHost:
def get_platform_adapters(self) -> list[platform_adapter.MessagePlatformAdapter]:
"""获取已启用的消息平台适配器列表
Returns:
list[platform.adapter.MessageSourceAdapter]: 已启用的消息平台适配器列表
"""
return self.ap.platform_mgr.get_running_adapters()
async def send_active_message(
self,
adapter: platform_adapter.MessagePlatformAdapter,
@@ -137,7 +135,7 @@ class APIHost:
message: platform_message.MessageChain,
):
"""发送主动消息
Args:
adapter (platform.adapter.MessageSourceAdapter): 消息平台适配器对象,调用 host.get_platform_adapters() 获取并取用其中某个
target_type (str): 目标类型,`person`或`group`
@@ -153,7 +151,7 @@ class APIHost:
def require_ver(
self,
ge: str,
le: str='v999.999.999',
le: str = 'v999.999.999',
) -> bool:
"""插件版本要求装饰器
@@ -164,16 +162,23 @@ class APIHost:
Returns:
bool: 是否满足要求, False时为无法获取版本号True时为满足要求报错为不满足要求
"""
langbot_version = ""
langbot_version = ''
try:
langbot_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号
except:
langbot_version = (
self.ap.ver_mgr.get_current_version()
) # 从updater模块获取版本号
except Exception:
return False
if self.ap.ver_mgr.compare_version_str(langbot_version, ge) < 0 or \
(self.ap.ver_mgr.compare_version_str(langbot_version, le) > 0):
raise Exception("LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, langbot_version))
if self.ap.ver_mgr.compare_version_str(langbot_version, ge) < 0 or (
self.ap.ver_mgr.compare_version_str(langbot_version, le) > 0
):
raise Exception(
'LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}'.format(
ge, le, langbot_version
)
)
return True
@@ -220,36 +225,30 @@ class EventContext:
if key not in self.__return_value__:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
async def reply(self, message_chain: platform_message.MessageChain):
"""回复此次消息请求
Args:
message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
"""
# TODO 添加 at_sender 和 quote_origin 参数
await self.event.query.adapter.reply_message(
message_source=self.event.query.message_event,
message=message_chain
message_source=self.event.query.message_event, message=message_chain
)
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
"""主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
"""
await self.event.query.adapter.send_message(
target_type=target_type,
target_id=target_id,
message=message
target_type=target_type, target_id=target_id, message=message
)
def prevent_postorder(self):
@@ -281,10 +280,8 @@ class EventContext:
def is_prevented_postorder(self):
"""是否阻止后序插件执行"""
return self.__prevent_postorder__
def __init__(self, host: APIHost, event: events.BaseEventModel):
self.eid = EventContext.eid
self.host = host
self.event = event
@@ -297,16 +294,16 @@ class EventContext:
class RuntimeContainerStatus(enum.Enum):
"""插件容器状态"""
MOUNTED = "mounted"
MOUNTED = 'mounted'
"""已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态"""
INITIALIZED = "initialized"
INITIALIZED = 'initialized'
"""已初始化"""
class RuntimeContainer(pydantic.BaseModel):
"""运行时的插件容器
运行期间存储单个插件的信息
"""
@@ -352,9 +349,10 @@ class RuntimeContainer(pydantic.BaseModel):
plugin_inst: typing.Optional[BasePlugin] = None
"""插件实例"""
event_handlers: dict[typing.Type[events.BaseEventModel], typing.Callable[
[BasePlugin, EventContext], typing.Awaitable[None]
]] = {}
event_handlers: dict[
typing.Type[events.BaseEventModel],
typing.Callable[[BasePlugin, EventContext], typing.Awaitable[None]],
] = {}
"""事件处理器"""
tools: list[tools_entities.LLMFunction] = []
@@ -378,7 +376,7 @@ class RuntimeContainer(pydantic.BaseModel):
'pkg_path': self.pkg_path,
'enabled': self.enabled,
'priority': self.priority,
"config_schema": self.config_schema,
'config_schema': self.config_schema,
'event_handlers': {
event_name.__name__: handler.__name__
for event_name, handler in self.event_handlers.items()

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
class PluginSystemError(Exception):
message: str
def __init__(self, message: str):
@@ -10,15 +9,13 @@ class PluginSystemError(Exception):
def __str__(self):
return self.message
class PluginNotFoundError(PluginSystemError):
def __init__(self, message: str):
super().__init__(f"未找到插件: {message}")
super().__init__(f'未找到插件: {message}')
class PluginInstallerError(PluginSystemError):
def __init__(self, message: str):
super().__init__(f"安装器操作错误: {message}")
super().__init__(f'安装器操作错误: {message}')

View File

@@ -27,7 +27,7 @@ class PersonMessageReceived(BaseEventModel):
launcher_id: typing.Union[int, str]
"""发起对象ID(群号/QQ号)"""
sender_id: typing.Union[int, str]
"""发送者ID(QQ号)"""
@@ -40,7 +40,7 @@ class GroupMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
message_chain: platform_message.MessageChain
@@ -52,7 +52,7 @@ class PersonNormalMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
text_message: str
@@ -70,7 +70,7 @@ class PersonCommandSent(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
command: str
@@ -94,7 +94,7 @@ class GroupNormalMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
text_message: str
@@ -112,7 +112,7 @@ class GroupCommandSent(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
command: str
@@ -136,7 +136,7 @@ class NormalMessageResponded(BaseEventModel):
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
session: core_entities.Session

View File

@@ -2,8 +2,8 @@
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
# 最早将于 v3.4 移除此模块
from . events import *
from . context import EventContext, APIHost as PluginHost
from .events import *
def emit(*args, **kwargs):
print('插件调用了已弃用的函数 pkg.plugin.host.emit()')
print('插件调用了已弃用的函数 pkg.plugin.host.emit()')

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import typing
import abc
from ..core import app, taskmgr
@@ -23,8 +22,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""安装插件
"""
"""安装插件"""
raise NotImplementedError
@abc.abstractmethod
@@ -33,17 +31,15 @@ class PluginInstaller(metaclass=abc.ABCMeta):
plugin_name: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""卸载插件
"""
"""卸载插件"""
raise NotImplementedError
@abc.abstractmethod
async def update_plugin(
self,
plugin_name: str,
plugin_source: str=None,
plugin_source: str = None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件
"""
"""更新插件"""
raise NotImplementedError

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import re
import os
import shutil
import zipfile
import ssl
import certifi
@@ -18,33 +17,37 @@ from ...core import taskmgr
class GitHubRepoInstaller(installer.PluginInstaller):
"""GitHub仓库插件安装器
"""
"""GitHub仓库插件安装器"""
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
"""获取username, repo"""
repo = re.findall(
r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)",
r'(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)',
repo_url,
)
if len(repo) > 0:
return repo[0].split("/")
return repo[0].split('/')
else:
return None
async def download_plugin_source_code(self, repo_url: str, target_path: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder()) -> str:
async def download_plugin_source_code(
self,
repo_url: str,
target_path: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
) -> str:
"""下载插件源码(全异步)"""
repo = self.get_github_plugin_repo_label(repo_url)
if repo is None:
raise errors.PluginInstallerError('仅支持GitHub仓库地址')
target_path += repo[1]
self.ap.logger.debug("正在下载源码...")
task_context.trace("下载源码...", "download-plugin-source-code")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
self.ap.logger.debug('正在下载源码...')
task_context.trace('下载源码...', 'download-plugin-source-code')
zipball_url = f'https://api.github.com/repos/{"/".join(repo)}/zipball/HEAD'
zip_resp: bytes = None
# 创建自定义SSL上下文使用certifi提供的根证书
ssl_context = ssl.create_default_context(cafile=certifi.where())
@@ -52,41 +55,44 @@ class GitHubRepoInstaller(installer.PluginInstaller):
async with session.get(
url=zipball_url,
timeout=aiohttp.ClientTimeout(total=300),
ssl=ssl_context # 使用自定义SSL上下文来验证证书
ssl=ssl_context, # 使用自定义SSL上下文来验证证书
) as resp:
if resp.status != 200:
raise errors.PluginInstallerError(f"下载源码失败: {await resp.text()}")
raise errors.PluginInstallerError(
f'下载源码失败: {await resp.text()}'
)
zip_resp = await resp.read()
if await aiofiles_os.path.exists("temp/" + target_path):
await aioshutil.rmtree("temp/" + target_path)
if await aiofiles_os.path.exists('temp/' + target_path):
await aioshutil.rmtree('temp/' + target_path)
if await aiofiles_os.path.exists(target_path):
await aioshutil.rmtree(target_path)
await aiofiles_os.makedirs("temp/" + target_path)
await aiofiles_os.makedirs('temp/' + target_path)
async with aiofiles.open("temp/" + target_path + "/source.zip", "wb") as f:
async with aiofiles.open('temp/' + target_path + '/source.zip', 'wb') as f:
await f.write(zip_resp)
self.ap.logger.debug("解压中...")
task_context.trace("解压中...", "unzip-plugin-source-code")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
await aiofiles_os.remove("temp/" + target_path + "/source.zip")
self.ap.logger.debug('解压中...')
task_context.trace('解压中...', 'unzip-plugin-source-code')
with zipfile.ZipFile('temp/' + target_path + '/source.zip', 'r') as zip_ref:
zip_ref.extractall('temp/' + target_path)
await aiofiles_os.remove('temp/' + target_path + '/source.zip')
import glob
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
await aioshutil.copytree(unzip_dir, target_path + "/")
unzip_dir = glob.glob('temp/' + target_path + '/*')[0]
await aioshutil.copytree(unzip_dir, target_path + '/')
await aioshutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
self.ap.logger.debug('源码下载完成。')
return repo[1]
async def install_requirements(self, path: str):
if os.path.exists(path + "/requirements.txt"):
pkgmgr.install_requirements(path + "/requirements.txt")
if os.path.exists(path + '/requirements.txt'):
pkgmgr.install_requirements(path + '/requirements.txt')
async def install_plugin(
self,
@@ -94,12 +100,14 @@ class GitHubRepoInstaller(installer.PluginInstaller):
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""安装插件"""
task_context.trace("下载插件源码...", "install-plugin")
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/", task_context)
task_context.trace("安装插件依赖...", "install-plugin")
await self.install_requirements("plugins/" + repo_label)
task_context.trace("完成.", "install-plugin")
task_context.trace('下载插件源码...', 'install-plugin')
repo_label = await self.download_plugin_source_code(
plugin_source, 'plugins/', task_context
)
task_context.trace('安装插件依赖...', 'install-plugin')
await self.install_requirements('plugins/' + repo_label)
task_context.trace('完成.', 'install-plugin')
# Caution: in the v4.0, plugin without manifest will not be able to be updated
# await self.ap.plugin_mgr.setting.record_installed_plugin_source(
# "plugins/" + repo_label + '/', plugin_source
@@ -115,9 +123,9 @@ class GitHubRepoInstaller(installer.PluginInstaller):
if plugin_container is None:
raise errors.PluginInstallerError('插件不存在或未成功加载')
else:
task_context.trace("删除插件目录...", "uninstall-plugin")
task_context.trace('删除插件目录...', 'uninstall-plugin')
await aioshutil.rmtree(plugin_container.pkg_path)
task_context.trace("完成, 重新加载以生效.", "uninstall-plugin")
task_context.trace('完成, 重新加载以生效.', 'uninstall-plugin')
async def update_plugin(
self,
@@ -126,14 +134,14 @@ class GitHubRepoInstaller(installer.PluginInstaller):
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件"""
task_context.trace("更新插件...", "update-plugin")
task_context.trace('更新插件...', 'update-plugin')
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is None:
raise errors.PluginInstallerError('插件不存在或未成功加载')
else:
if plugin_container.plugin_repository:
plugin_source = plugin_container.plugin_repository
task_context.trace("转交安装任务.", "update-plugin")
task_context.trace('转交安装任务.', 'update-plugin')
await self.install_plugin(plugin_source, task_context)
else:
raise errors.PluginInstallerError('插件无源码信息,无法更新')
raise errors.PluginInstallerError('插件无源码信息,无法更新')

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
from abc import ABCMeta
import typing
import abc
from ..core import app
from . import context, events
from . import context
class PluginLoader(metaclass=abc.ABCMeta):
@@ -25,4 +23,3 @@ class PluginLoader(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def load_plugins(self):
pass

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import typing
import pkgutil
import importlib
import os
import traceback
from .. import loader, events, context, models
@@ -11,7 +10,6 @@ from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from ...utils import funcschema
from ...discover import engine as discover_engine
from ...utils import pkgmgr
class PluginLoader(loader.PluginLoader):
@@ -36,17 +34,17 @@ class PluginLoader(loader.PluginLoader):
"""初始化"""
def register(
self,
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]:
self, name: str, description: str, version: str, author: str
) -> typing.Callable[
[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]
]:
self.ap.logger.debug(f'注册插件 {name} {version} by {author}')
container = context.RuntimeContainer(
plugin_name=name,
plugin_label=discover_engine.I18nString(en_US=name, zh_CN=name),
plugin_description=discover_engine.I18nString(en_US=description, zh_CN=description),
plugin_description=discover_engine.I18nString(
en_US=description, zh_CN=description
),
plugin_version=version,
plugin_author=author,
plugin_repository='',
@@ -61,20 +59,21 @@ class PluginLoader(loader.PluginLoader):
def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]:
container.plugin_class = cls
return cls
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on(
self,
event: typing.Type[events.BaseEventModel]
self, event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册过时的事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None:
async def handler(
plugin: context.BasePlugin, ctx: context.EventContext
) -> None:
args = {
'host': ctx.host,
'event': ctx,
@@ -82,12 +81,12 @@ class PluginLoader(loader.PluginLoader):
# 把 ctx.event 所有的属性都放到 args 里
# for k, v in ctx.event.dict().items():
# args[k] = v
# args[k] = v
for attr_name in ctx.event.__dict__.keys():
args[attr_name] = getattr(ctx.event, attr_name)
func(plugin, **args)
self._current_container.event_handlers[event] = handler
return func
@@ -98,20 +97,21 @@ class PluginLoader(loader.PluginLoader):
# 最早将于 v3.4 版本移除
def func(
self,
name: str=None,
name: str = None,
) -> typing.Callable:
"""注册过时的内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
function_name = (
self._current_container.plugin_name
+ '-'
+ (func.__name__ if name is None else name)
)
async def handler(
plugin: context.BasePlugin,
query: core_entities.Query,
*args,
**kwargs
plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs
):
return func(*args, **kwargs)
@@ -126,18 +126,19 @@ class PluginLoader(loader.PluginLoader):
self._current_container.tools.append(llm_function)
return func
return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
self, event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here
def wrapper(func: typing.Callable) -> typing.Callable:
if (
self._current_container is None
): # None indicates this plugin is registered through manifest, so ignore it here
return func
self._current_container.event_handlers[event] = func
@@ -148,17 +149,23 @@ class PluginLoader(loader.PluginLoader):
def llm_func(
self,
name: str=None,
name: str = None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here
if (
self._current_container is None
): # None indicates this plugin is registered through manifest, so ignore it here
return func
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
function_name = (
self._current_container.plugin_name
+ '-'
+ (func.__name__ if name is None else name)
)
llm_function = tools_entities.LLMFunction(
name=function_name,
@@ -171,43 +178,40 @@ class PluginLoader(loader.PluginLoader):
self._current_container.tools.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,
prefix='',
path_prefix=''
):
"""遍历插件路径
"""
async def _walk_plugin_path(self, module, prefix='', path_prefix=''):
"""遍历插件路径"""
for item in pkgutil.iter_modules(module.__path__):
if item.ispkg:
await self._walk_plugin_path(
__import__(module.__name__ + "." + item.name, fromlist=[""]),
prefix + item.name + ".",
path_prefix + item.name + "/",
__import__(module.__name__ + '.' + item.name, fromlist=['']),
prefix + item.name + '.',
path_prefix + item.name + '/',
)
else:
try:
self._current_pkg_path = "plugins/" + path_prefix
self._current_module_path = "plugins/" + path_prefix + item.name + ".py"
self._current_pkg_path = 'plugins/' + path_prefix
self._current_module_path = (
'plugins/' + path_prefix + item.name + '.py'
)
self._current_container = None
importlib.import_module(module.__name__ + "." + item.name)
importlib.import_module(module.__name__ + '.' + item.name)
if self._current_container is not None:
self.plugins.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except:
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
except Exception:
self.ap.logger.error(
f'加载插件模块 {prefix + item.name} 时发生错误'
)
traceback.print_exc()
async def load_plugins(self):
"""加载插件
"""
"""加载插件"""
setattr(models, 'register', self.register)
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
@@ -215,4 +219,4 @@ class PluginLoader(loader.PluginLoader):
setattr(context, 'register', self.register)
setattr(context, 'handler', self.handler)
setattr(context, 'llm_func', self.llm_func)
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
await self._walk_plugin_path(__import__('plugins', fromlist=['']))

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import typing
import abc
import os
import traceback
from ...core import app
from .. import context, events, models
from .. import context, events
from .. import loader
from ...utils import funcschema
from ...provider.tools import entities as tools_entities
@@ -21,13 +20,12 @@ class PluginManifestLoader(loader.PluginLoader):
super().__init__(ap)
def handler(
self,
event: typing.Type[events.BaseEventModel]
self, event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
@@ -36,14 +34,18 @@ class PluginManifestLoader(loader.PluginLoader):
def llm_func(
self,
name: str=None,
name: str = None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
function_name = (
self._current_container.plugin_name
+ '-'
+ (func.__name__ if name is None else name)
)
llm_function = tools_entities.LLMFunction(
name=function_name,
@@ -56,7 +58,7 @@ class PluginManifestLoader(loader.PluginLoader):
self._current_container.tools.append(llm_function)
return func
return wrapper
async def load_plugins(self):
@@ -68,7 +70,11 @@ class PluginManifestLoader(loader.PluginLoader):
for plugin_manifest in plugin_manifests:
try:
config_schema = plugin_manifest.spec['config'] if 'config' in plugin_manifest.spec else []
config_schema = (
plugin_manifest.spec['config']
if 'config' in plugin_manifest.spec
else []
)
current_plugin_container = context.RuntimeContainer(
plugin_name=plugin_manifest.metadata.name,
@@ -77,7 +83,9 @@ class PluginManifestLoader(loader.PluginLoader):
plugin_version=plugin_manifest.metadata.version,
plugin_author=plugin_manifest.metadata.author,
plugin_repository=plugin_manifest.metadata.repository,
main_file=os.path.join(plugin_manifest.rel_dir, plugin_manifest.execution.python.path),
main_file=os.path.join(
plugin_manifest.rel_dir, plugin_manifest.execution.python.path
),
pkg_path=plugin_manifest.rel_dir,
config_schema=config_schema,
event_handlers={},
@@ -95,6 +103,8 @@ class PluginManifestLoader(loader.PluginLoader):
# TODO load component extensions
self.plugins.append(current_plugin_container)
except Exception as e:
self.ap.logger.error(f'加载插件 {plugin_manifest.metadata.name} 时发生错误')
except Exception:
self.ap.logger.error(
f'加载插件 {plugin_manifest.metadata.name} 时发生错误'
)
traceback.print_exc()

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
import typing
import traceback
import sqlalchemy
import logging
from ..core import app, taskmgr
from . import context, loader, events, installer, models
@@ -28,28 +26,26 @@ class PluginManager:
def plugins(
self,
enabled: bool=None,
status: context.RuntimeContainerStatus=None,
enabled: bool = None,
status: context.RuntimeContainerStatus = None,
) -> list[context.RuntimeContainer]:
"""获取插件列表
"""
"""获取插件列表"""
plugins = self.plugin_containers
if enabled is not None:
plugins = [plugin for plugin in plugins if plugin.enabled == enabled]
if status is not None:
plugins = [plugin for plugin in plugins if plugin.status == status]
return plugins
def get_plugin(
self,
author: str,
plugin_name: str,
) -> context.RuntimeContainer:
"""通过作者和插件名获取插件
"""
"""通过作者和插件名获取插件"""
for plugin in self.plugins():
if plugin.plugin_author == author and plugin.plugin_name == plugin_name:
return plugin
@@ -88,20 +84,24 @@ class PluginManager:
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugin_containers}')
async def load_plugin_settings(
self,
plugin_containers: list[context.RuntimeContainer]
self, plugin_containers: list[context.RuntimeContainer]
):
for plugin_container in plugin_containers:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_plugin.PluginSetting) \
.where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name)
sqlalchemy.select(persistence_plugin.PluginSetting)
.where(
persistence_plugin.PluginSetting.plugin_author
== plugin_container.plugin_author
)
.where(
persistence_plugin.PluginSetting.plugin_name
== plugin_container.plugin_name
)
)
setting = result.first()
if setting is None:
new_setting_data = {
'plugin_author': plugin_container.plugin_author,
'plugin_name': plugin_container.plugin_name,
@@ -111,7 +111,9 @@ class PluginManager:
}
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_plugin.PluginSetting).values(**new_setting_data)
sqlalchemy.insert(persistence_plugin.PluginSetting).values(
**new_setting_data
)
)
continue
else:
@@ -120,19 +122,23 @@ class PluginManager:
plugin_container.plugin_config = setting.config
async def dump_plugin_container_setting(
self,
plugin_container: context.RuntimeContainer
self, plugin_container: context.RuntimeContainer
):
"""保存单个插件容器的设置到数据库
"""
"""保存单个插件容器的设置到数据库"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_plugin.PluginSetting)
.where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name)
.where(
persistence_plugin.PluginSetting.plugin_author
== plugin_container.plugin_author
)
.where(
persistence_plugin.PluginSetting.plugin_name
== plugin_container.plugin_name
)
.values(
enabled=plugin_container.enabled,
priority=plugin_container.priority,
config=plugin_container.plugin_config
config=plugin_container.plugin_config,
)
)
@@ -160,13 +166,13 @@ class PluginManager:
async def destroy_plugin(self, plugin: context.RuntimeContainer):
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
return
self.ap.logger.debug(f'释放插件 {plugin.plugin_name}')
plugin.plugin_inst.__del__()
await plugin.plugin_inst.destroy()
plugin.plugin_inst = None
plugin.status = context.RuntimeContainerStatus.MOUNTED
async def destroy_plugins(self):
for plugin in self.plugins():
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
@@ -185,16 +191,15 @@ class PluginManager:
plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""安装插件
"""
"""安装插件"""
await self.installer.install_plugin(plugin_source, task_context)
await self.ap.ctr_mgr.plugin.post_install_record(
{
"name": "unknown",
"remote": plugin_source,
"author": "unknown",
"version": "HEAD"
'name': 'unknown',
'remote': plugin_source,
'author': 'unknown',
'version': 'HEAD',
}
)
@@ -206,8 +211,7 @@ class PluginManager:
plugin_name: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""卸载插件
"""
"""卸载插件"""
plugin_container = self.get_plugin_by_name(plugin_name)
@@ -219,10 +223,10 @@ class PluginManager:
await self.ap.ctr_mgr.plugin.post_remove_record(
{
"name": plugin_name,
"remote": plugin_container.plugin_repository,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
'name': plugin_name,
'remote': plugin_container.plugin_repository,
'author': plugin_container.plugin_author,
'version': plugin_container.plugin_version,
}
)
@@ -232,80 +236,82 @@ class PluginManager:
async def update_plugin(
self,
plugin_name: str,
plugin_source: str=None,
plugin_source: str = None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件
"""
"""更新插件"""
await self.installer.update_plugin(plugin_name, plugin_source, task_context)
plugin_container = self.get_plugin_by_name(plugin_name)
await self.ap.ctr_mgr.plugin.post_update_record(
plugin={
"name": plugin_name,
"remote": plugin_container.plugin_repository,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
'name': plugin_name,
'remote': plugin_container.plugin_repository,
'author': plugin_container.plugin_author,
'version': plugin_container.plugin_version,
},
old_version=plugin_container.plugin_version,
new_version="HEAD"
new_version='HEAD',
)
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
"""
"""通过插件名获取插件"""
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
return plugin
return None
async def emit_event(self, event: events.BaseEventModel) -> context.EventContext:
"""触发事件
"""
"""触发事件"""
ctx = context.EventContext(host=self.api_host, event=event)
ctx = context.EventContext(
host=self.api_host,
event=event
)
emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins(
enabled=True,
status=context.RuntimeContainerStatus.INITIALIZED
enabled=True, status=context.RuntimeContainerStatus.INITIALIZED
):
if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
self.ap.logger.debug(
f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}'
)
is_prevented_default_before_call = ctx.is_prevented_default()
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
plugin.plugin_inst, ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
self.ap.logger.error(
f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}'
)
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
self.ap.logger.debug(
f'插件 {plugin.plugin_name} 阻止了默认行为执行'
)
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
self.ap.logger.debug(
f'插件 {plugin.plugin_name} 阻止了后序插件的执行'
)
break
for key in ctx.__return_value__.keys():
if hasattr(ctx.event, key):
setattr(ctx.event, key, ctx.__return_value__[key][0])
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
self.ap.logger.debug(
f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}'
)
if emitted_plugins:
plugins_info: list[dict] = [
@@ -313,13 +319,13 @@ class PluginManager:
'name': plugin.plugin_name,
'remote': plugin.plugin_repository,
'version': plugin.plugin_version,
'author': plugin.plugin_author
} for plugin in emitted_plugins
'author': plugin.plugin_author,
}
for plugin in emitted_plugins
]
await self.ap.ctr_mgr.usage.post_event_record(
plugins=plugins_info,
event_name=event.__class__.__name__
plugins=plugins_info, event_name=event.__class__.__name__
)
return ctx
@@ -330,7 +336,7 @@ class PluginManager:
if plugin.plugin_name == plugin_name:
if plugin.enabled == new_status:
return False
# 初始化/释放插件
if new_status:
await self.initialize_plugin(plugin)
@@ -338,7 +344,7 @@ class PluginManager:
await self.destroy_plugin(plugin)
plugin.enabled = new_status
await self.dump_plugin_container_setting(plugin)
break
@@ -348,7 +354,6 @@ class PluginManager:
return False
async def reorder_plugins(self, plugins: list[dict]):
for plugin in plugins:
plugin_name = plugin.get('name')
plugin_priority = plugin.get('priority')
@@ -363,7 +368,9 @@ class PluginManager:
for plugin in self.plugin_containers:
await self.dump_plugin_container_setting(plugin)
async def set_plugin_config(self, plugin_container: context.RuntimeContainer, new_config: dict):
async def set_plugin_config(
self, plugin_container: context.RuntimeContainer, new_config: dict
):
plugin_container.plugin_config = new_config
plugin_container.plugin_inst.config = new_config

View File

@@ -9,22 +9,20 @@ import typing
from .context import BasePlugin as Plugin
from .events import *
def register(
name: str,
description: str,
version: str,
author
name: str, description: str, version: str, author
) -> typing.Callable[[typing.Type[Plugin]], typing.Type[Plugin]]:
pass
def on(
event: typing.Type[BaseEventModel]
event: typing.Type[BaseEventModel],
) -> typing.Callable[[typing.Callable], typing.Callable]:
pass
def func(
name: str=None,
name: str = None,
) -> typing.Callable:
pass