feat: 异步风格插件方法注册器

This commit is contained in:
RockChinQ
2024-03-20 15:09:47 +08:00
parent fa823de6b0
commit 52a7c25540
9 changed files with 210 additions and 9 deletions

View File

@@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类"""
host: APIHost
"""API宿主"""
ap: app.Application
"""应用程序对象"""
def __init__(self, host: APIHost):
self.host = host
async def initialize(self):
"""初始化插件"""
pass
class APIHost:
@@ -61,8 +72,10 @@ class EventContext:
"""事件编号"""
host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False
"""是否阻止默认行为"""

View File

@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: typing.Union[core_entities.Query, None]
"""此次请求的query对象可能为None"""
class Config:
arbitrary_types_allowed = True

View File

@@ -5,11 +5,10 @@ import pkgutil
import importlib
import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from ...utils import funcschema
class PluginLoader(loader.PluginLoader):
@@ -29,6 +28,9 @@ class PluginLoader(loader.PluginLoader):
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
setattr(models, 'handler', self.handler)
setattr(models, 'llm_func', self.llm_func)
def register(
self,
name: str,
@@ -57,6 +59,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on(
self,
event: typing.Type[events.BaseEventModel]
@@ -83,6 +87,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def func(
self,
name: str=None,
@@ -91,10 +97,11 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func)
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
plugin: context.BasePlugin,
query: core_entities.Query,
*args,
**kwargs
@@ -116,6 +123,46 @@ class PluginLoader(loader.PluginLoader):
return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
return wrapper
def llm_func(
self,
name: str=None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,

View File

@@ -5,7 +5,7 @@ import traceback
from ..core import app
from . import context, loader, events, installer, setting, models
from .loaders import legacy
from .loaders import classic
from .installers import github
@@ -26,7 +26,7 @@ class PluginManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.loader = legacy.PluginLoader(ap)
self.loader = classic.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
@@ -52,6 +52,9 @@ class PluginManager:
for plugin in self.plugins:
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)

View File

@@ -24,3 +24,15 @@ def func(
name: str=None,
) -> typing.Callable:
pass
def handler(
event: typing.Type[BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
pass
def llm_func(
name: str=None,
) -> typing.Callable:
pass