refactor: move plugin setting to db

This commit is contained in:
Junyan Qin
2025-04-12 20:21:43 +08:00
parent 11342e75de
commit ebd091a9e0
10 changed files with 130 additions and 127 deletions

View File

@@ -44,20 +44,28 @@ class PluginsRouterGroup(group.RouterGroup):
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>', methods=['DELETE'], auth_type=group.AuthType.USER_TOKEN)
@self.route('/<author>/<plugin_name>', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}',
context=ctx
)
if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
return self.success(data={
'plugin': plugin.model_dump()
})
elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}',
context=ctx
)
return self.success(data={
'task_id': wrapper.id
})
return self.success(data={
'task_id': wrapper.id
})
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:

View File

@@ -0,0 +1,16 @@
import sqlalchemy
from .base import Base
class PluginSetting(Base):
"""插件配置"""
__tablename__ = 'plugin_settings'
plugin_author = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
plugin_name = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now())

View File

@@ -8,7 +8,7 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from ..entity.persistence import base, user, model, pipeline, bot
from ..entity.persistence import base, user, model, pipeline, bot, plugin
from ..core import app
from .databases import sqlite

View File

@@ -339,6 +339,12 @@ class RuntimeContainer(pydantic.BaseModel):
priority: typing.Optional[int] = 0
"""优先级"""
config_schema: typing.Optional[list[dict]] = []
"""插件配置模板"""
plugin_config: typing.Optional[dict] = {}
"""插件配置"""
plugin_inst: typing.Optional[BasePlugin] = None
"""插件实例"""
@@ -389,6 +395,7 @@ class RuntimeContainer(pydantic.BaseModel):
'pkg_path': self.pkg_path,
'enabled': self.enabled,
'priority': self.priority,
"config_schema": self.config_schema,
'event_handlers': {
event_name.__name__: handler.__name__
for event_name, handler in self.event_handlers.items()

View File

@@ -99,9 +99,11 @@ class GitHubRepoInstaller(installer.PluginInstaller):
task_context.trace("安装插件依赖...", "install-plugin")
await self.install_requirements("plugins/" + repo_label)
task_context.trace("完成.", "install-plugin")
await self.ap.plugin_mgr.setting.record_installed_plugin_source(
"plugins/" + repo_label + '/', plugin_source
)
# Caution: in the v4.0, plugin without manifest will not be able to be updated
# await self.ap.plugin_mgr.setting.record_installed_plugin_source(
# "plugins/" + repo_label + '/', plugin_source
# )
async def uninstall_plugin(
self,

View File

@@ -133,7 +133,10 @@ class PluginLoader(loader.PluginLoader):
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here
return func
self._current_container.event_handlers[event] = func
return func
@@ -148,6 +151,9 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
if self._current_container is None: # None indicates this plugin is registered through manifest, so ignore it here
return func
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)

View File

@@ -68,6 +68,8 @@ class PluginManifestLoader(loader.PluginLoader):
for plugin_manifest in plugin_manifests:
try:
config_schema = plugin_manifest.spec['config'] if 'config' in plugin_manifest.spec else []
current_plugin_container = context.RuntimeContainer(
plugin_name=plugin_manifest.metadata.name,
plugin_label=plugin_manifest.metadata.label,
@@ -77,6 +79,7 @@ class PluginManifestLoader(loader.PluginLoader):
plugin_repository=plugin_manifest.metadata.repository,
main_file=os.path.join(plugin_manifest.rel_dir, plugin_manifest.execution.python.path),
pkg_path=plugin_manifest.rel_dir,
config_schema=config_schema,
event_handlers={},
tools=[],
)

View File

@@ -3,10 +3,13 @@ from __future__ import annotations
import typing
import traceback
import sqlalchemy
from ..core import app, taskmgr
from . import context, loader, events, installer, setting, models
from . import context, loader, events, installer, models
from .loaders import classic, manifest
from .installers import github
from ..entity.persistence import plugin as persistence_plugin
class PluginManager:
@@ -18,8 +21,6 @@ class PluginManager:
installer: installer.PluginInstaller
setting: setting.SettingManager
api_host: context.APIHost
plugin_containers: list[context.RuntimeContainer]
@@ -40,6 +41,18 @@ class PluginManager:
plugins = [plugin for plugin in plugins if plugin.status == status]
return plugins
def get_plugin(
self,
author: str,
plugin_name: str,
) -> context.RuntimeContainer:
"""通过作者和插件名获取插件
"""
for plugin in self.plugins():
if plugin.plugin_author == author and plugin.plugin_name == plugin_name:
return plugin
return None
def __init__(self, ap: app.Application):
self.ap = ap
@@ -48,7 +61,6 @@ class PluginManager:
manifest.PluginManifestLoader(ap),
]
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
self.plugin_containers = []
@@ -56,23 +68,73 @@ class PluginManager:
for loader in self.loaders:
await loader.initialize()
await self.installer.initialize()
await self.setting.initialize()
await self.api_host.initialize()
setattr(models, 'require_ver', self.api_host.require_ver)
async def load_plugins(self):
self.ap.logger.info('Loading all plugins...')
for loader in self.loaders:
await loader.load_plugins()
self.plugin_containers.extend(loader.plugins)
await self.setting.sync_setting(self.plugin_containers)
await self.load_plugin_settings(self.plugin_containers)
# 按优先级倒序
self.plugin_containers.sort(key=lambda x: x.priority, reverse=True)
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugin_containers}')
async def load_plugin_settings(
self,
plugin_containers: list[context.RuntimeContainer]
):
for plugin_container in plugin_containers:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_plugin.PluginSetting) \
.where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name)
)
setting = result.first()
if setting is None:
new_setting_data = {
'plugin_author': plugin_container.plugin_author,
'plugin_name': plugin_container.plugin_name,
'enabled': plugin_container.enabled,
'priority': plugin_container.priority,
'config': plugin_container.plugin_config,
}
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_plugin.PluginSetting).values(**new_setting_data)
)
continue
else:
plugin_container.enabled = setting.enabled
plugin_container.priority = setting.priority
plugin_container.plugin_config = setting.config
async def dump_plugin_container_setting(
self,
plugin_container: context.RuntimeContainer
):
"""保存单个插件容器的设置到数据库
"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_plugin.PluginSetting)
.where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author)
.where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name)
.values(
enabled=plugin_container.enabled,
priority=plugin_container.priority,
config=plugin_container.plugin_config
)
)
async def initialize_plugin(self, plugin: context.RuntimeContainer):
self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}')
plugin.plugin_inst = plugin.plugin_class(self.api_host)
@@ -275,7 +337,7 @@ class PluginManager:
plugin.enabled = new_status
await self.setting.dump_container_setting(self.plugin_containers)
await self.dump_plugin_container_setting(self.plugin_containers)
break
@@ -296,4 +358,4 @@ class PluginManager:
self.plugin_containers.sort(key=lambda x: x.priority, reverse=True)
await self.setting.dump_container_setting(self.plugin_containers)
await self.dump_plugin_container_setting(self.plugin_containers)

View File

@@ -1,101 +0,0 @@
from __future__ import annotations
from ..core import app
from ..config import manager as cfg_mgr
from . import context
class SettingManager:
"""插件设置管理器"""
ap: app.Application
settings: cfg_mgr.ConfigManager
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.settings = self.ap.plugin_setting_meta
async def sync_setting(
self,
plugin_containers: list[context.RuntimeContainer],
):
"""同步设置
"""
not_matched_source_record = []
for value in self.settings.data['plugins']:
if 'name' not in value: # 只有远程地址的应用到pkg_path相同的插件容器上
matched = False
for plugin_container in plugin_containers:
if plugin_container.pkg_path == value['pkg_path']:
matched = True
plugin_container.plugin_repository = value['source']
break
if not matched:
not_matched_source_record.append(value)
else: # 正常的插件设置
for plugin_container in plugin_containers:
if plugin_container.plugin_name == value['name']:
plugin_container.set_from_setting_dict(value)
break
self.settings.data = {
'plugins': [
p.to_setting_dict()
for p in plugin_containers
]
}
self.settings.data['plugins'].extend(not_matched_source_record)
await self.settings.dump_config()
async def dump_container_setting(
self,
plugin_containers: list[context.RuntimeContainer]
):
"""保存插件容器设置
"""
for plugin in plugin_containers:
for ps in self.settings.data['plugins']:
if ps['name'] == plugin.plugin_name:
plugin_dict = plugin.to_setting_dict()
for key in plugin_dict:
ps[key] = plugin_dict[key]
break
await self.settings.dump_config()
async def record_installed_plugin_source(
self,
pkg_path: str,
source: str
):
found = False
for value in self.settings.data['plugins']:
if value['pkg_path'] == pkg_path:
value['source'] = source
found = True
break
if not found:
self.settings.data['plugins'].append(
{
'pkg_path': pkg_path,
'source': source
}
)
await self.settings.dump_config()

View File

@@ -219,7 +219,7 @@ class VersionManager:
try:
if await self.ap.ver_mgr.is_new_version_available():
return "有新版本可用,请使用管理员账号发送 !update 命令更新", logging.INFO
return "有新版本可用,根据文档更新https://docs.langbot.app/deploy/update.html", logging.INFO
except Exception as e:
return f"检查版本更新时出错: {e}", logging.WARNING