refactor: 重构部分插件管理逻辑

This commit is contained in:
Junyan Qin
2024-11-16 16:13:02 +08:00
parent bb219889e5
commit 658eb278c4
13 changed files with 219 additions and 158 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing
import abc
import pydantic
import enum
from . import events
from ..provider.tools import entities as tools_entities
@@ -85,10 +86,18 @@ class BasePlugin(metaclass=abc.ABCMeta):
"""应用程序对象"""
def __init__(self, host: APIHost):
"""初始化阶段被调用"""
self.host = host
async def initialize(self):
"""初始化插件"""
"""初始化阶段被调用"""
pass
async def destroy(self):
"""释放/禁用插件时被调用"""
pass
def __del__(self):
pass
@@ -247,6 +256,16 @@ class EventContext:
EventContext.eid += 1
class RuntimeContainerStatus(enum.Enum):
"""插件容器状态"""
MOUNTED = "mounted"
"""已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态"""
INITIALIZED = "initialized"
"""已初始化"""
class RuntimeContainer(pydantic.BaseModel):
"""运行时的插件容器
@@ -294,6 +313,9 @@ class RuntimeContainer(pydantic.BaseModel):
content_functions: list[tools_entities.LLMFunction] = []
"""内容函数"""
status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED
"""插件状态"""
class Config:
arbitrary_types_allowed = True
@@ -318,9 +340,6 @@ class RuntimeContainer(pydantic.BaseModel):
self.priority = setting['priority']
self.enabled = setting['enabled']
for function in self.content_functions:
function.enable = self.enabled
def model_dump(self, *args, **kwargs):
return {
'name': self.plugin_name,
@@ -342,9 +361,9 @@ class RuntimeContainer(pydantic.BaseModel):
'human_desc': function.human_desc,
'description': function.description,
'parameters': function.parameters,
'enable': function.enable,
'func': function.func.__name__,
}
for function in self.content_functions
],
'status': self.status.value,
}

View File

@@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta):
ap: app.Application
plugins: list[context.RuntimeContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.plugins = []
async def initialize(self):
pass
@abc.abstractmethod
async def load_plugins(self) -> list[context.RuntimeContainer]:
async def load_plugins(self):
pass

View File

@@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader):
_current_container: context.RuntimeContainer = None
containers: list[context.RuntimeContainer] = []
plugins: list[context.RuntimeContainer] = []
def __init__(self, ap):
self.ap = ap
self.plugins = []
self._current_pkg_path = ''
self._current_module_path = ''
self._current_container = None
async def initialize(self):
"""初始化"""
@@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader):
}
# 把 ctx.event 所有的属性都放到 args 里
for k, v in ctx.event.dict().items():
args[k] = v
# for k, v in ctx.event.dict().items():
# args[k] = v
for attr_name in ctx.event.__dict__.keys():
args[attr_name] = getattr(ctx.event, attr_name)
func(plugin, **args)
@@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=handler,
)
@@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
@@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader):
importlib.import_module(module.__name__ + "." + item.name)
if self._current_container is not None:
self.containers.append(self._current_container)
self.plugins.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except:
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
traceback.print_exc()
async def load_plugins(self) -> list[context.RuntimeContainer]:
async def load_plugins(self):
"""加载插件
"""
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
return self.containers

View File

@@ -22,7 +22,22 @@ class PluginManager:
api_host: context.APIHost
plugins: list[context.RuntimeContainer]
def plugins(
self,
enabled: bool=None,
status: context.RuntimeContainerStatus=None,
) -> list[context.RuntimeContainer]:
"""获取插件列表
"""
plugins = self.loader.plugins
if enabled is not None:
plugins = [plugin for plugin in plugins if plugin.enabled == enabled]
if status is not None:
plugins = [plugin for plugin in plugins if plugin.status == status]
return plugins
def __init__(self, ap: app.Application):
self.ap = ap
@@ -30,7 +45,6 @@ class PluginManager:
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
self.plugins = []
async def initialize(self):
await self.loader.initialize()
@@ -41,27 +55,57 @@ class PluginManager:
setattr(models, 'require_ver', self.api_host.require_ver)
async def load_plugins(self):
self.plugins = await self.loader.load_plugins()
await self.loader.load_plugins()
await self.setting.sync_setting(self.plugins)
await self.setting.sync_setting(self.loader.plugins)
# 按优先级倒序
self.plugins.sort(key=lambda x: x.priority, reverse=True)
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}')
self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}')
async def initialize_plugin(self, plugin: context.RuntimeContainer):
self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}')
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
plugin.status = context.RuntimeContainerStatus.INITIALIZED
async def initialize_plugins(self):
for plugin in self.plugins:
for plugin in self.plugins():
if not plugin.enabled:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化')
continue
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
await self.initialize_plugin(plugin)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)
continue
async def destroy_plugin(self, plugin: context.RuntimeContainer):
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
return
self.ap.logger.debug(f'释放插件 {plugin.plugin_name}')
await plugin.plugin_inst.destroy()
plugin.plugin_inst = None
plugin.status = context.RuntimeContainerStatus.MOUNTED
async def destroy_plugins(self):
for plugin in self.plugins():
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放')
continue
try:
await self.destroy_plugin(plugin)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}')
self.ap.logger.exception(e)
continue
async def install_plugin(
self,
plugin_source: str,
@@ -127,7 +171,7 @@ class PluginManager:
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
"""
for plugin in self.plugins:
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
return plugin
return None
@@ -143,30 +187,32 @@ class PluginManager:
emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
for plugin in self.plugins(
enabled=True,
status=context.RuntimeContainerStatus.INITIALIZED
):
if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
for key in ctx.__return_value__.keys():
if hasattr(ctx.event, key):
@@ -191,16 +237,22 @@ class PluginManager:
return ctx
async def update_plugin_status(self, plugin_name: str, new_status: bool):
async def update_plugin_switch(self, plugin_name: str, new_status: bool):
if self.get_plugin_by_name(plugin_name) is not None:
for plugin in self.plugins:
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in plugin.content_functions:
func.enable = new_status
if plugin.enabled == new_status:
return False
await self.setting.dump_container_setting(self.plugins)
# 初始化/释放插件
if new_status:
await self.initialize_plugin(plugin)
else:
await self.destroy_plugin(plugin)
plugin.enabled = new_status
await self.setting.dump_container_setting(self.loader.plugins)
break
@@ -214,11 +266,11 @@ class PluginManager:
plugin_name = plugin.get('name')
plugin_priority = plugin.get('priority')
for plugin in self.plugins:
for plugin in self.loader.plugins:
if plugin.plugin_name == plugin_name:
plugin.priority = plugin_priority
break
self.plugins.sort(key=lambda x: x.priority, reverse=True)
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
await self.setting.dump_container_setting(self.plugins)
await self.setting.dump_container_setting(self.loader.plugins)