diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 41097f27..3b44e9cc 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -10,7 +10,6 @@ required_deps = { "botpy": "qq-botpy", "PIL": "pillow", "nakuru": "nakuru-project-idk", - "CallingGPT": "CallingGPT", "tiktoken": "tiktoken", "yaml": "pyyaml", "aiohttp": "aiohttp", diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index a982232f..329914f0 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -13,6 +13,17 @@ class BasePlugin(metaclass=abc.ABCMeta): """插件基类""" host: APIHost + """API宿主""" + + ap: app.Application + """应用程序对象""" + + def __init__(self, host: APIHost): + self.host = host + + async def initialize(self): + """初始化插件""" + pass class APIHost: @@ -61,8 +72,10 @@ class EventContext: """事件编号""" host: APIHost = None + """API宿主""" event: events.BaseEventModel = None + """此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义""" __prevent_default__ = False """是否阻止默认行为""" diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index fe67a82d..bcb8e5c8 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -10,8 +10,10 @@ from ..provider import entities as llm_entities class BaseEventModel(pydantic.BaseModel): + """事件模型基类""" query: typing.Union[core_entities.Query, None] + """此次请求的query对象,可能为None""" class Config: arbitrary_types_allowed = True diff --git a/pkg/plugin/loaders/legacy.py b/pkg/plugin/loaders/classic.py similarity index 74% rename from pkg/plugin/loaders/legacy.py rename to pkg/plugin/loaders/classic.py index 9bbee7c0..d5be6ace 100644 --- a/pkg/plugin/loaders/legacy.py +++ b/pkg/plugin/loaders/classic.py @@ -5,11 +5,10 @@ import pkgutil import importlib import traceback -from CallingGPT.entities.namespace import get_func_schema - from .. import loader, events, context, models, host from ...core import entities as core_entities from ...provider.tools import entities as tools_entities +from ...utils import funcschema class PluginLoader(loader.PluginLoader): @@ -29,6 +28,9 @@ class PluginLoader(loader.PluginLoader): setattr(models, 'on', self.on) setattr(models, 'func', self.func) + setattr(models, 'handler', self.handler) + setattr(models, 'llm_func', self.llm_func) + def register( self, name: str, @@ -57,6 +59,8 @@ class PluginLoader(loader.PluginLoader): return wrapper + # 过时 + # 最早将于 v3.4 版本移除 def on( self, event: typing.Type[events.BaseEventModel] @@ -83,6 +87,8 @@ class PluginLoader(loader.PluginLoader): return wrapper + # 过时 + # 最早将于 v3.4 版本移除 def func( self, name: str=None, @@ -91,10 +97,11 @@ class PluginLoader(loader.PluginLoader): self.ap.logger.debug(f'注册内容函数 {name}') def wrapper(func: typing.Callable) -> typing.Callable: - function_schema = get_func_schema(func) + function_schema = funcschema.get_func_schema(func) function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) async def handler( + plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs @@ -116,6 +123,46 @@ class PluginLoader(loader.PluginLoader): return wrapper + def handler( + self, + event: typing.Type[events.BaseEventModel] + ) -> typing.Callable[[typing.Callable], typing.Callable]: + """注册事件处理器""" + self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: + + self._current_container.event_handlers[event] = func + + return func + + return wrapper + + def llm_func( + self, + name: str=None, + ) -> typing.Callable: + """注册内容函数""" + self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: + + function_schema = funcschema.get_func_schema(func) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + + llm_function = tools_entities.LLMFunction( + name=function_name, + human_desc='', + description=function_schema['description'], + enable=True, + parameters=function_schema['parameters'], + func=func, + ) + + self._current_container.content_functions.append(llm_function) + + return func + + return wrapper + async def _walk_plugin_path( self, module, diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index 06e94f98..13a114a5 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -5,7 +5,7 @@ import traceback from ..core import app from . import context, loader, events, installer, setting, models -from .loaders import legacy +from .loaders import classic from .installers import github @@ -26,7 +26,7 @@ class PluginManager: def __init__(self, ap: app.Application): self.ap = ap - self.loader = legacy.PluginLoader(ap) + self.loader = classic.PluginLoader(ap) self.installer = github.GitHubRepoInstaller(ap) self.setting = setting.SettingManager(ap) self.api_host = context.APIHost(ap) @@ -52,6 +52,9 @@ class PluginManager: for plugin in self.plugins: try: plugin.plugin_inst = plugin.plugin_class(self.api_host) + plugin.plugin_inst.ap = self.ap + plugin.plugin_inst.host = self.api_host + await plugin.plugin_inst.initialize() except Exception as e: self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.exception(e) diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index 972eed11..642305e4 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -24,3 +24,15 @@ def func( name: str=None, ) -> typing.Callable: pass + + +def handler( + event: typing.Type[BaseEventModel] +) -> typing.Callable[[typing.Callable], typing.Callable]: + pass + + +def llm_func( + name: str=None, +) -> typing.Callable: + pass \ No newline at end of file diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 72c892bb..616de713 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -5,6 +5,7 @@ import traceback from ...core import app, entities as core_entities from . import entities +from ...plugin import context as plugin_context class ToolManager: @@ -28,6 +29,15 @@ class ToolManager: return function return None + async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: + """获取函数和插件 + """ + for plugin in self.ap.plugin_mgr.plugins: + for function in plugin.content_functions: + if function.name == name: + return function, plugin + return None, None + async def get_all_functions(self) -> list[entities.LLMFunction]: """获取所有函数 """ @@ -68,7 +78,7 @@ class ToolManager: try: - function = await self.get_function(name) + function, plugin = await self.get_function_and_plugin(name) if function is None: return None @@ -79,7 +89,7 @@ class ToolManager: **parameters } - return await function.func(**parameters) + return await function.func(plugin, **parameters) except Exception as e: self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') traceback.print_exc() diff --git a/pkg/utils/funcschema.py b/pkg/utils/funcschema.py new file mode 100644 index 00000000..c39b4886 --- /dev/null +++ b/pkg/utils/funcschema.py @@ -0,0 +1,116 @@ +import sys +import re +import inspect + + +def get_func_schema(function: callable) -> dict: + """ + Return the data schema of a function. + { + "function": function, + "description": "function description", + "parameters": { + "type": "object", + "properties": { + "parameter_a": { + "type": "str", + "description": "parameter_a description" + }, + "parameter_b": { + "type": "int", + "description": "parameter_b description" + }, + "parameter_c": { + "type": "str", + "description": "parameter_c description", + "enum": ["a", "b", "c"] + }, + }, + "required": ["parameter_a", "parameter_b"] + } + } + """ + func_doc = function.__doc__ + # Google Style Docstring + if func_doc is None: + raise Exception("Function {} has no docstring.".format(function.__name__)) + func_doc = func_doc.strip().replace(' ','').replace('\t', '') + # extract doc of args from docstring + doc_spt = func_doc.split('\n\n') + desc = doc_spt[0] + args = doc_spt[1] if len(doc_spt) > 1 else "" + returns = doc_spt[2] if len(doc_spt) > 2 else "" + + # extract args + # delete the first line of args + arg_lines = args.split('\n')[1:] + arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args) + args_doc = {} + for arg_line in arg_lines: + doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line) + if len(doc_tuple) == 0: + continue + args_doc[doc_tuple[0][0]] = doc_tuple[0][3] + + # extract returns + return_doc_list = re.findall(r'(\w+):\s*(.*)', returns) + + params = enumerate(inspect.signature(function).parameters.values()) + parameters = { + "type": "object", + "required": [], + "properties": {}, + } + + + for i, param in params: + + # 排除 self, query + if param.name in ['self', 'query']: + continue + + param_type = param.annotation.__name__ + + type_name_mapping = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "list": "array", + "dict": "object", + } + + if param_type in type_name_mapping: + param_type = type_name_mapping[param_type] + + parameters['properties'][param.name] = { + "type": param_type, + "description": args_doc[param.name], + } + + # add schema for array + if param_type == "array": + # extract type of array, the int of list[int] + # use re + array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation)) + + array_type = 'string' + + if len(array_type_tuple) > 0: + array_type = array_type_tuple[0] + + if array_type in type_name_mapping: + array_type = type_name_mapping[array_type] + + parameters['properties'][param.name]["items"] = { + "type": array_type, + } + + if param.default is inspect.Parameter.empty: + parameters["required"].append(param.name) + + return { + "function": function, + "description": desc, + "parameters": parameters, + } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6a3e718c..28c0ecb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ aiocqhttp qq-botpy nakuru-project-idk Pillow -CallingGPT tiktoken PyYaml aiohttp