mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-17 19:24:19 +00:00
refactor: 重构插件系统
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pkgutil
|
||||
import importlib
|
||||
import traceback
|
||||
|
||||
from CallingGPT.entities.namespace import get_func_schema
|
||||
|
||||
from .. import loader, events, context, models, host
|
||||
from ...core import entities as core_entities
|
||||
from ...provider.tools import entities as tools_entities
|
||||
|
||||
|
||||
class PluginLoader(loader.PluginLoader):
|
||||
"""加载 plugins/ 目录下的插件"""
|
||||
|
||||
_current_pkg_path = ''
|
||||
|
||||
_current_module_path = ''
|
||||
|
||||
_current_container: context.RuntimeContainer = None
|
||||
|
||||
containers: list[context.RuntimeContainer] = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化"""
|
||||
setattr(models, 'register', self.register)
|
||||
setattr(models, 'on', self.on)
|
||||
setattr(models, 'func', self.func)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
version: str,
|
||||
author: str
|
||||
) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]:
|
||||
self.ap.logger.debug(f'注册插件 {name} {version} by {author}')
|
||||
container = context.RuntimeContainer(
|
||||
plugin_name=name,
|
||||
plugin_description=description,
|
||||
plugin_version=version,
|
||||
plugin_author=author,
|
||||
plugin_source='',
|
||||
pkg_path=self._current_pkg_path,
|
||||
main_file=self._current_module_path,
|
||||
event_handlers={},
|
||||
content_functions=[],
|
||||
)
|
||||
|
||||
self._current_container = container
|
||||
|
||||
def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]:
|
||||
container.plugin_class = cls
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
def on(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册过时的事件处理器"""
|
||||
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None:
|
||||
args = {
|
||||
'host': ctx.host,
|
||||
'event': ctx,
|
||||
}
|
||||
|
||||
# 把 ctx.event 所有的属性都放到 args 里
|
||||
for k, v in ctx.event.dict().items():
|
||||
args[k] = v
|
||||
|
||||
await func(plugin, **args)
|
||||
|
||||
self._current_container.event_handlers[event] = handler
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def func(
|
||||
self,
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册过时的内容函数"""
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
async def handler(
|
||||
query: core_entities.Query,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
llm_function = tools_entities.LLMFunction(
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=handler,
|
||||
)
|
||||
|
||||
self._current_container.content_functions.append(llm_function)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _walk_plugin_path(
|
||||
self,
|
||||
module,
|
||||
prefix='',
|
||||
path_prefix=''
|
||||
):
|
||||
"""遍历插件路径
|
||||
"""
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.ispkg:
|
||||
await self._walk_plugin_path(
|
||||
__import__(module.__name__ + "." + item.name, fromlist=[""]),
|
||||
prefix + item.name + ".",
|
||||
path_prefix + item.name + "/",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self._current_pkg_path = "plugins/" + path_prefix
|
||||
self._current_module_path = "plugins/" + path_prefix + item.name + ".py"
|
||||
|
||||
self._current_container = None
|
||||
|
||||
importlib.import_module(module.__name__ + "." + item.name)
|
||||
|
||||
if self._current_container is not None:
|
||||
self.containers.append(self._current_container)
|
||||
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
|
||||
except:
|
||||
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
|
||||
traceback.print_exc()
|
||||
|
||||
async def load_plugins(self) -> list[context.RuntimeContainer]:
|
||||
"""加载插件
|
||||
"""
|
||||
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
|
||||
|
||||
return self.containers
|
||||
Reference in New Issue
Block a user