refactor: 重构插件系统

This commit is contained in:
RockChinQ
2024-01-29 21:22:27 +08:00
parent b730f17eb6
commit 6cc4688660
53 changed files with 1307 additions and 1993 deletions

View File

@@ -4,7 +4,6 @@ import typing
from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from ..provider.session import entities as session_entities
from . import entities, operator, errors
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
@@ -80,7 +79,7 @@ class CommandManager:
self,
command_text: str,
query: core_entities.Query,
session: session_entities.Session
session: core_entities.Session
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""

View File

@@ -6,7 +6,6 @@ import pydantic
import mirai
from ..core import app, entities as core_entities
from ..provider.session import entities as session_entities
from . import errors, operator
@@ -28,7 +27,7 @@ class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query
session: session_entities.Session
session: core_entities.Session
command_text: str

View File

@@ -4,7 +4,6 @@ import typing
import abc
from ..core import app, entities as core_entities
from ..provider.session import entities as session_entities
from . import entities

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import host as plugin_host
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@@ -13,7 +12,10 @@ class FuncOperator(operator.CommandOperator):
reply_str = "当前已加载的内容函数: \n\n"
index = 1
for func in self.ap.tool_mgr.all_functions:
all_functions = await self.ap.tool_mgr.get_all_functions()
for func in all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format(
index,
("(已禁用) " if not func.enable else ""),

View File

@@ -3,8 +3,6 @@ import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...plugin import host as plugin_host
from ...utils import updater
from ...core import app
@@ -20,16 +18,15 @@ class PluginOperator(operator.CommandOperator):
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = plugin_host.__plugins__
reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__))
plugin_list = self.ap.plugin_mgr.plugins
reply_str = "所有插件({}):\n".format(len(plugin_list))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
for plugin in plugin_list:
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
.format((idx+1), plugin.plugin_name,
"[已禁用]" if not plugin.enabled else "",
plugin.plugin_description,
plugin.plugin_version, plugin.plugin_author)
# TODO 从元数据调远程地址
# if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
@@ -63,7 +60,7 @@ class PluginGetOperator(operator.CommandOperator):
yield entities.CommandReturn(text="正在安装插件...")
try:
plugin_host.install_plugin(repo)
await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
except Exception as e:
traceback.print_exc()
@@ -89,11 +86,11 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_path_name is not None:
if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...")
plugin_host.update_plugin(plugin_name)
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
@@ -115,17 +112,17 @@ class PluginUpdateAllOperator(operator.CommandOperator):
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = []
for key in plugin_host.__plugins__:
plugins.append(key)
plugins = [
p.plugin_name
for p in self.ap.plugin_mgr.plugins
]
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
updated = []
try:
for plugin_name in plugins:
plugin_host.update_plugin(plugin_name)
await self.ap.plugin_mgr.update_plugin(plugin_name)
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
@@ -157,11 +154,11 @@ class PluginDelOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_path_name is not None:
if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...")
plugin_host.uninstall_plugin(plugin_name)
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
@@ -171,12 +168,15 @@ class PluginDelOperator(operator.CommandOperator):
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if plugin_name in plugin_host.__plugins__:
plugin_host.__plugins__[plugin_name]['enabled'] = new_status
if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in ap.tool_mgr.all_functions:
if func.name.startswith(plugin_name+'-'):
func.enable = new_status
for func in plugin.content_functions:
func.enable = new_status
break
return True
else:

View File

@@ -4,7 +4,6 @@ import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...utils import updater
@operator.operator_class(
@@ -22,7 +21,7 @@ class UpdateCommand(operator.CommandOperator):
try:
yield entities.CommandReturn(text="正在进行更新...")
if updater.update_all():
if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
else:
yield entities.CommandReturn(text="当前已是最新版本")

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
from ...utils import updater
@operator.operator_class(
@@ -17,10 +16,10 @@ class VersionCommand(operator.CommandOperator):
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{updater.get_current_version_info()}"
reply_str = f"当前版本: \n{await self.ap.ver_mgr.get_current_version_info()}"
try:
if updater.is_new_version_available():
if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用, 使用 !update 更新"
except:
pass