Merge pull request #920 from RockChinQ/feat/lifetime-controlling

Feat: 生命周期和热重载
This commit is contained in:
Junyan Qin
2024-11-16 17:19:42 +08:00
committed by GitHub
29 changed files with 447 additions and 367 deletions

13
main.py
View File

@@ -3,13 +3,14 @@
# QChatGPT/main.py
asciiart = r"""
___ ___ _ _ ___ ___ _____
/ _ \ / __| |_ __ _| |_ / __| _ \_ _|
| (_) | (__| ' \/ _` | _| (_ | _/ | |
\__\_\\___|_||_\__,_|\__|\___|_| |_|
_ ___ _
| | __ _ _ _ __ _| _ ) ___| |_
| |__/ _` | ' \/ _` | _ \/ _ \ _|
|____\__,_|_||_\__, |___/\___/\__|
|___/
⭐️开源地址: https://github.com/RockChinQ/QChatGPT
📖文档地址: https://q.rkcn.top
⭐️开源地址: https://github.com/RockChinQ/LangBot
📖文档地址: https://docs.langbot.app
"""

View File

@@ -15,7 +15,7 @@ class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
plugins = self.ap.plugin_mgr.plugins
plugins = self.ap.plugin_mgr.plugins()
plugins_data = [plugin.model_dump() for plugin in plugins]
@@ -27,7 +27,7 @@ class PluginsRouterGroup(group.RouterGroup):
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled)
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'])

View File

@@ -39,3 +39,25 @@ class SystemRouterGroup(group.RouterGroup):
return self.http_status(404, 404, "Task not found")
return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'])
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get("scope")
await self.ap.reload(
scope=scope
)
return self.success()
@self.route('/_debug/exec', methods=['POST'])
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden")
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {"ap": ap}))

View File

@@ -6,7 +6,7 @@ import os
import quart
import quart_cors
from ....core import app
from ....core import app, entities as core_entities
from .groups import logs, system, settings, plugins, stats
from . import group
@@ -32,15 +32,26 @@ class HTTPController:
while True:
await asyncio.sleep(1)
async def exception_handler(*args, **kwargs):
try:
await self.quart_app.run_task(
*args, **kwargs
)
except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
self.ap.task_mgr.create_task(
self.quart_app.run_task(
exception_handler(
host=self.ap.system_cfg.data["http-api"]["host"],
port=self.ap.system_cfg.data["http-api"]["port"],
shutdown_trigger=shutdown_trigger_placeholder,
),
name="http-api-quart",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
# await asyncio.sleep(5)
async def register_routes(self) -> None:
@self.quart_app.route("/healthz")

View File

@@ -9,7 +9,7 @@ import asyncio
import aiohttp
import requests
from ...core import app
from ...core import app, entities as core_entities
class APIGroup(metaclass=abc.ABCMeta):
@@ -65,14 +65,12 @@ class APIGroup(metaclass=abc.ABCMeta):
**kwargs,
) -> asyncio.Task:
"""执行请求"""
# task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
# self.ap.asyncio_tasks.append(task)
return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation",
name=f"{method} {path}",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task
def gen_rid(self):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import context as plugin_context
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@@ -9,16 +10,18 @@ class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n"
reply_str = "当前已启用的内容函数: \n\n"
index = 1
all_functions = await self.ap.tool_mgr.get_all_functions()
all_functions = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
)
for func in all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format(
reply_str += "{}. {}:\n{}\n\n".format(
index,
("(已禁用) " if not func.enable else ""),
func.name,
func.description,
)

View File

@@ -18,7 +18,7 @@ class PluginOperator(operator.CommandOperator):
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = "所有插件({}):\n".format(len(plugin_list))
idx = 0
for plugin in plugin_list:
@@ -110,7 +110,7 @@ class PluginUpdateAllOperator(operator.CommandOperator):
try:
plugins = [
p.plugin_name
for p in self.ap.plugin_mgr.plugins
for p in self.ap.plugin_mgr.plugins()
]
if plugins:
@@ -182,7 +182,7 @@ class PluginEnableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_status(plugin_name, True):
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
@@ -210,7 +210,7 @@ class PluginDisableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_status(plugin_name, False):
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))

View File

@@ -4,6 +4,8 @@ import logging
import asyncio
import threading
import traceback
import enum
import sys
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
@@ -21,8 +23,9 @@ from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..utils import logcache
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
class Application:
@@ -104,24 +107,84 @@ class Application:
pass
async def run(self):
await self.plugin_mgr.initialize_plugins()
try:
await self.plugin_mgr.initialize_plugins()
# 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
async def never_ending():
while True:
await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager")
self.task_mgr.create_task(self.ctrl.run(), name="query-controller")
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller")
self.task_mgr.create_task(never_ending(), name="never-ending-task")
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
await self.print_web_access_info()
await self.task_mgr.wait_all()
except asyncio.CancelledError:
pass
except Exception as e:
self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}")
async def print_web_access_info(self):
"""打印访问 webui 的提示"""
import socket
host_ip = socket.gethostbyname(socket.gethostname())
public_ip = await ip.get_myip()
port = self.system_cfg.data['http-api']['port']
tips = f"""
=======================================
✨ 您可通过以下方式访问管理面板
🏠 本地地址http://{host_ip}:{port}/
🌐 公网地址http://{public_ip}:{port}/
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
🔗 若要使用公网地址访问,请阅读以下须知
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
2. 要使用公网地址访问,请确保您的主机具有公网 IP并且系统防火墙已放行 {port} 端口;
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
=======================================
""".strip()
for line in tips.split("\n"):
self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info("执行热重载 scope="+scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith("plugins."):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case _:
pass

View File

@@ -53,13 +53,17 @@ async def main(loop: asyncio.AbstractEventLoop):
# 挂系统信号处理
import signal
ap: app.Application
def signal_handler(sig, frame):
print("[Signal] 程序退出.")
# ap.shutdown()
os._exit(0)
signal.signal(signal.SIGINT, signal_handler)
app_inst = await make_app(loop)
ap = app_inst
await app_inst.run()
except Exception as e:
traceback.print_exc()

View File

@@ -6,7 +6,7 @@ required_deps = {
"anthropic": "anthropic",
"colorlog": "colorlog",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy",
"botpy": "qq-botpy-rc",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken",

View File

@@ -17,6 +17,14 @@ from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""

View File

@@ -6,6 +6,7 @@ import datetime
import traceback
from . import app
from . import entities as core_entities
class TaskContext:
@@ -71,7 +72,7 @@ class TaskWrapper:
task_type: str = "system" # 任务类型: system 或 user
"""任务类型"""
kind: str = "system_task"
kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同
"""任务种类"""
name: str = ""
@@ -92,6 +93,9 @@ class TaskWrapper:
ap: app.Application
"""应用实例"""
scopes: list[core_entities.LifecycleControlScope]
"""任务所属生命周期控制范围"""
def __init__(
self,
ap: app.Application,
@@ -101,6 +105,7 @@ class TaskWrapper:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
@@ -112,6 +117,7 @@ class TaskWrapper:
self.name = name
self.label = label if label != "" else name
self.task.set_name(name)
self.scopes = scopes
def assume_exception(self):
try:
@@ -145,6 +151,7 @@ class TaskWrapper:
"kind": self.kind,
"name": self.name,
"label": self.label,
"scopes": [scope.value for scope in self.scopes],
"task_context": self.task_context.to_dict(),
"runtime": {
"done": self.task.done(),
@@ -154,6 +161,9 @@ class TaskWrapper:
"result": self.assume_result().__str__() if self.assume_result() is not None else None,
},
}
def cancel(self):
self.task.cancel()
class AsyncTaskManager:
@@ -177,8 +187,9 @@ class AsyncTaskManager:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context)
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper)
return wrapper
@@ -189,8 +200,9 @@ class AsyncTaskManager:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
return self.create_task(coro, "user", kind, name, label, context)
return self.create_task(coro, "user", kind, name, label, context, scopes)
async def wait_all(self):
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
@@ -214,3 +226,10 @@ class AsyncTaskManager:
if t.id == id:
return t
return None
def cancel_by_scope(self, scope: core_entities.LifecycleControlScope):
for wrapper in self.tasks:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()

View File

@@ -4,7 +4,6 @@ import asyncio
import typing
import traceback
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
@@ -59,13 +58,11 @@ class Controller:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
# task = asyncio.create_task(_process_query(selected_query))
# self.ap.asyncio_tasks.append(task)
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind="query",
name=f"query-{selected_query.query_id}",
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
)
except Exception as e:
@@ -166,6 +163,23 @@ class Controller:
async def process_query(self, query: entities.Query):
"""处理请求
"""
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
try:
@@ -173,7 +187,6 @@ class Controller:
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

View File

@@ -37,76 +37,40 @@ class PlatformManager:
async def initialize(self):
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
from .sources import nakuru, aiocqhttp, qqbotpy
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
launcher_type='person',
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
launcher_type='person',
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
launcher_type='group',
launcher_id=event.group.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
index = 0
@@ -174,24 +138,30 @@ class PlatformManager:
try:
tasks = []
for adapter in self.adapters:
async def exception_wrapper(adapter):
async def exception_wrapper(adapter: msadapter.MessageSourceAdapter):
try:
await adapter.run_async()
except Exception as e:
if isinstance(e, asyncio.CancelledError):
return
self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
tasks.append(exception_wrapper(adapter))
for task in tasks:
# async_task = asyncio.create_task(task)
# self.ap.asyncio_tasks.append(async_task)
self.ap.task_mgr.create_task(
task,
kind="platform-adapter",
name=f"platform-adapter-{adapter.name}",
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM],
)
except Exception as e:
self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
async def shutdown(self):
for adapter in self.adapters:
await adapter.kill()
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)

View File

@@ -328,5 +328,5 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
while True:
await asyncio.sleep(1)
def kill(self) -> bool:
async def kill(self) -> bool:
return False

View File

@@ -21,7 +21,6 @@ from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
class OfficialGroupMessage(platform_events.GroupMessage):
pass
@@ -588,8 +587,12 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.member_openid_mapping, self.group_openid_mapping
)
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(**self.cfg)
self.cfg['ret_coro'] = True
def kill(self) -> bool:
return False
self.ap.logger.info("运行 QQ 官方适配器")
await (await self.bot.start(**self.cfg))
async def kill(self) -> bool:
if not self.bot.is_closed():
await self.bot.close()
return True

View File

@@ -1,121 +0,0 @@
# import asyncio
# import typing
# from .. import adapter as adapter_model
# from ...core import app
# @adapter_model.adapter_class("yiri-mirai")
# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
# """YiriMirai适配器"""
# bot: mirai.Mirai
# def __init__(self, config: dict, ap: app.Application):
# """初始化YiriMirai的对象"""
# self.ap = ap
# self.config = config
# if 'adapter' not in config or \
# config['adapter'] == 'WebSocketAdapter':
# self.bot = mirai.Mirai(
# qq=config['qq'],
# adapter=mirai.WebSocketAdapter(
# host=config['host'],
# port=config['port'],
# verify_key=config['verifyKey']
# )
# )
# elif config['adapter'] == 'HTTPAdapter':
# self.bot = mirai.Mirai(
# qq=config['qq'],
# adapter=mirai.HTTPAdapter(
# host=config['host'],
# port=config['port'],
# verify_key=config['verifyKey']
# )
# )
# else:
# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
# async def send_message(
# self,
# target_type: str,
# target_id: str,
# message: mirai.MessageChain
# ):
# """发送消息
# Args:
# target_type (str): 目标类型,`person`或`group`
# target_id (str): 目标ID
# message (mirai.MessageChain): YiriMirai库的消息链
# """
# task = None
# if target_type == 'person':
# task = self.bot.send_friend_message(int(target_id), message)
# elif target_type == 'group':
# task = self.bot.send_group_message(int(target_id), message)
# else:
# raise Exception('Unknown target type: ' + target_type)
# await task
# async def reply_message(
# self,
# message_source: mirai.MessageEvent,
# message: mirai.MessageChain,
# quote_origin: bool = False
# ):
# """回复消息
# Args:
# message_source (mirai.MessageEvent): YiriMirai消息源事件
# message (mirai.MessageChain): YiriMirai库的消息链
# quote_origin (bool, optional): 是否引用原消息. Defaults to False.
# """
# await self.bot.send(message_source, message, quote_origin)
# async def is_muted(self, group_id: int) -> bool:
# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
# if result.mute_time_remaining > 0:
# return True
# return False
# def register_listener(
# self,
# event_type: typing.Type[mirai.Event],
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
# ):
# """注册事件监听器
# Args:
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
# callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
# """
# async def wrapper(event: mirai.Event):
# await callback(event, self)
# self.bot.on(event_type)(wrapper)
# def unregister_listener(
# self,
# event_type: typing.Type[mirai.Event],
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
# ):
# """注销事件监听器
# Args:
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
# callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
# """
# assert isinstance(self.bot, mirai.Mirai)
# bus = self.bot.bus
# assert isinstance(bus, mirai.models.bus.ModelEventBus)
# bus.unsubscribe(event_type, callback)
# async def run_async(self):
# self.bot_account_id = self.bot.qq
# return await MiraiRunner(self.bot)._run()
# async def kill(self) -> bool:
# return False

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing
import abc
import pydantic
import enum
from . import events
from ..provider.tools import entities as tools_entities
@@ -85,10 +86,19 @@ class BasePlugin(metaclass=abc.ABCMeta):
"""应用程序对象"""
def __init__(self, host: APIHost):
"""初始化阶段被调用"""
self.host = host
async def initialize(self):
"""初始化插件"""
"""初始化阶段被调用"""
pass
async def destroy(self):
"""释放/禁用插件时被调用"""
pass
def __del__(self):
"""释放/禁用插件时被调用"""
pass
@@ -247,6 +257,16 @@ class EventContext:
EventContext.eid += 1
class RuntimeContainerStatus(enum.Enum):
"""插件容器状态"""
MOUNTED = "mounted"
"""已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态"""
INITIALIZED = "initialized"
"""已初始化"""
class RuntimeContainer(pydantic.BaseModel):
"""运行时的插件容器
@@ -294,6 +314,9 @@ class RuntimeContainer(pydantic.BaseModel):
content_functions: list[tools_entities.LLMFunction] = []
"""内容函数"""
status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED
"""插件状态"""
class Config:
arbitrary_types_allowed = True
@@ -318,9 +341,6 @@ class RuntimeContainer(pydantic.BaseModel):
self.priority = setting['priority']
self.enabled = setting['enabled']
for function in self.content_functions:
function.enable = self.enabled
def model_dump(self, *args, **kwargs):
return {
'name': self.plugin_name,
@@ -342,9 +362,9 @@ class RuntimeContainer(pydantic.BaseModel):
'human_desc': function.human_desc,
'description': function.description,
'parameters': function.parameters,
'enable': function.enable,
'func': function.func.__name__,
}
for function in self.content_functions
],
'status': self.status.value,
}

View File

@@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta):
ap: app.Application
plugins: list[context.RuntimeContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.plugins = []
async def initialize(self):
pass
@abc.abstractmethod
async def load_plugins(self) -> list[context.RuntimeContainer]:
async def load_plugins(self):
pass

View File

@@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader):
_current_container: context.RuntimeContainer = None
containers: list[context.RuntimeContainer] = []
plugins: list[context.RuntimeContainer] = []
def __init__(self, ap):
self.ap = ap
self.plugins = []
self._current_pkg_path = ''
self._current_module_path = ''
self._current_container = None
async def initialize(self):
"""初始化"""
@@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader):
}
# 把 ctx.event 所有的属性都放到 args 里
for k, v in ctx.event.dict().items():
args[k] = v
# for k, v in ctx.event.dict().items():
# args[k] = v
for attr_name in ctx.event.__dict__.keys():
args[attr_name] = getattr(ctx.event, attr_name)
func(plugin, **args)
@@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=handler,
)
@@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
@@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader):
importlib.import_module(module.__name__ + "." + item.name)
if self._current_container is not None:
self.containers.append(self._current_container)
self.plugins.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except:
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
traceback.print_exc()
async def load_plugins(self) -> list[context.RuntimeContainer]:
async def load_plugins(self):
"""加载插件
"""
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
return self.containers

View File

@@ -22,7 +22,22 @@ class PluginManager:
api_host: context.APIHost
plugins: list[context.RuntimeContainer]
def plugins(
self,
enabled: bool=None,
status: context.RuntimeContainerStatus=None,
) -> list[context.RuntimeContainer]:
"""获取插件列表
"""
plugins = self.loader.plugins
if enabled is not None:
plugins = [plugin for plugin in plugins if plugin.enabled == enabled]
if status is not None:
plugins = [plugin for plugin in plugins if plugin.status == status]
return plugins
def __init__(self, ap: app.Application):
self.ap = ap
@@ -30,7 +45,6 @@ class PluginManager:
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
self.plugins = []
async def initialize(self):
await self.loader.initialize()
@@ -41,27 +55,58 @@ class PluginManager:
setattr(models, 'require_ver', self.api_host.require_ver)
async def load_plugins(self):
self.plugins = await self.loader.load_plugins()
await self.loader.load_plugins()
await self.setting.sync_setting(self.plugins)
await self.setting.sync_setting(self.loader.plugins)
# 按优先级倒序
self.plugins.sort(key=lambda x: x.priority, reverse=True)
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}')
self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}')
async def initialize_plugin(self, plugin: context.RuntimeContainer):
self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}')
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
plugin.status = context.RuntimeContainerStatus.INITIALIZED
async def initialize_plugins(self):
for plugin in self.plugins:
for plugin in self.plugins():
if not plugin.enabled:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化')
continue
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
await self.initialize_plugin(plugin)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)
continue
async def destroy_plugin(self, plugin: context.RuntimeContainer):
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
return
self.ap.logger.debug(f'释放插件 {plugin.plugin_name}')
plugin.plugin_inst.__del__()
await plugin.plugin_inst.destroy()
plugin.plugin_inst = None
plugin.status = context.RuntimeContainerStatus.MOUNTED
async def destroy_plugins(self):
for plugin in self.plugins():
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放')
continue
try:
await self.destroy_plugin(plugin)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}')
self.ap.logger.exception(e)
continue
async def install_plugin(
self,
plugin_source: str,
@@ -80,6 +125,9 @@ class PluginManager:
}
)
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
async def uninstall_plugin(
self,
plugin_name: str,
@@ -87,10 +135,15 @@ class PluginManager:
):
"""卸载插件
"""
await self.installer.uninstall_plugin(plugin_name, task_context)
plugin_container = self.get_plugin_by_name(plugin_name)
if plugin_container is None:
raise ValueError(f'插件 {plugin_name} 不存在')
await self.destroy_plugin(plugin_container)
await self.installer.uninstall_plugin(plugin_name, task_context)
await self.ap.ctr_mgr.plugin.post_remove_record(
{
"name": plugin_name,
@@ -100,6 +153,9 @@ class PluginManager:
}
)
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
async def update_plugin(
self,
plugin_name: str,
@@ -123,11 +179,13 @@ class PluginManager:
new_version="HEAD"
)
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
"""
for plugin in self.plugins:
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
return plugin
return None
@@ -143,30 +201,32 @@ class PluginManager:
emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
for plugin in self.plugins(
enabled=True,
status=context.RuntimeContainerStatus.INITIALIZED
):
if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
for key in ctx.__return_value__.keys():
if hasattr(ctx.event, key):
@@ -191,16 +251,22 @@ class PluginManager:
return ctx
async def update_plugin_status(self, plugin_name: str, new_status: bool):
async def update_plugin_switch(self, plugin_name: str, new_status: bool):
if self.get_plugin_by_name(plugin_name) is not None:
for plugin in self.plugins:
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in plugin.content_functions:
func.enable = new_status
if plugin.enabled == new_status:
return False
await self.setting.dump_container_setting(self.plugins)
# 初始化/释放插件
if new_status:
await self.initialize_plugin(plugin)
else:
await self.destroy_plugin(plugin)
plugin.enabled = new_status
await self.setting.dump_container_setting(self.loader.plugins)
break
@@ -214,11 +280,11 @@ class PluginManager:
plugin_name = plugin.get('name')
plugin_priority = plugin.get('priority')
for plugin in self.plugins:
for plugin in self.loader.plugins:
if plugin.plugin_name == plugin_name:
plugin.priority = plugin_priority
break
self.plugins.sort(key=lambda x: x.priority, reverse=True)
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
await self.setting.dump_container_setting(self.plugins)
await self.setting.dump_container_setting(self.loader.plugins)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
class SessionManager:
@@ -51,7 +52,10 @@ class SessionManager:
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(),
use_funcs=await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
),
)
session.conversations.append(conversation)
session.using_conversation = conversation

View File

@@ -20,8 +20,6 @@ class LLMFunction(pydantic.BaseModel):
description: str
"""给LLM识别的函数描述"""
enable: typing.Optional[bool] = True
parameters: dict
func: typing.Callable

View File

@@ -20,28 +20,25 @@ class ToolManager:
async def initialize(self):
pass
async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数"""
for function in await self.get_all_functions():
if function.name == name:
return function
return None
async def get_function_and_plugin(
self, name: str
) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件"""
for plugin in self.ap.plugin_mgr.plugins:
"""获取函数和插件实例"""
for plugin in self.ap.plugin_mgr.plugins(
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
):
for function in plugin.content_functions:
if function.name == name:
return function, plugin.plugin_inst
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]:
async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
for plugin in self.ap.plugin_mgr.plugins:
for plugin in self.ap.plugin_mgr.plugins(
enabled=plugin_enabled, status=plugin_status
):
all_functions.extend(plugin.content_functions)
return all_functions
@@ -51,16 +48,15 @@ class ToolManager:
tools = []
for function in use_funcs:
if function.enable:
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters,
},
}
tools.append(function_schema)
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters,
},
}
tools.append(function_schema)
return tools
@@ -92,13 +88,12 @@ class ToolManager:
tools = []
for function in use_funcs:
if function.enable:
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
}
tools.append(function_schema)
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
}
tools.append(function_schema)
return tools

9
pkg/utils/ip.py Normal file
View File

@@ -0,0 +1,9 @@
import aiohttp
async def get_myip() -> str:
try:
async with aiohttp.ClientSession() as session:
async with session.get("https://ip.useragentinfo.com/myip") as response:
return await response.text()
except Exception as e:
return '0.0.0.0'

View File

@@ -3,7 +3,7 @@ openai>1.0.0
anthropic
colorlog~=6.6.0
aiocqhttp
qq-botpy
qq-botpy-rc
nakuru-project-idk
Pillow
tiktoken

View File

@@ -1,13 +1,5 @@
{
"platform-adapters": [
{
"adapter": "yiri-mirai",
"enable": false,
"host": "127.0.0.1",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 123456789
},
{
"adapter": "nakuru",
"enable": false,

View File

@@ -9,43 +9,6 @@
"items": {
"type": "object",
"oneOf": [
{
"title": "YiriMirai 适配器",
"description": "用于接入 Mirai",
"properties": {
"adapter": {
"type": "string",
"const": "yiri-mirai"
},
"enable": {
"type": "boolean",
"default": false,
"description": "是否启用此适配器",
"layout": {
"comp": "switch",
"props": {
"color": "primary"
}
}
},
"host": {
"type": "string",
"default": "127.0.0.1"
},
"port": {
"type": "integer",
"default": 8080
},
"verifyKey": {
"type": "string",
"default": "yirimirai"
},
"qq": {
"type": "integer",
"default": 123456789
}
}
},
{
"title": "Nakuru 适配器",
"description": "用于接入 go-cqhttp",

View File

@@ -67,6 +67,17 @@
</v-list-item-title>
</v-list-item>
<v-list-item @click="reload('platform')">
<v-list-item-title>
重载消息平台
</v-list-item-title>
</v-list-item>
<v-list-item @click="reload('plugin')">
<v-list-item-title>
重载插件
</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
</v-list-item>
@@ -137,6 +148,30 @@ function openDocs() {
window.open('https://docs.langbot.app', '_blank')
}
const reloadScopeLabel = {
'platform': "消息平台",
'plugin': "插件"
}
function reload(scope) {
let label = reloadScopeLabel[scope]
proxy.$axios.post('/system/reload',
{ scope: scope },
{ headers: { 'Content-Type': 'application/json' } }
).then(response => {
if (response.data.code === 0) {
success(label+'已重载')
// 关闭菜单
} else {
error(label+'重载失败:' + response.data.message)
}
}).catch(err => {
error(label+'重载失败:' + err)
})
}
const aboutDialogShow = ref(false)
function showAboutDialog() {
@@ -162,10 +197,6 @@ function closeAboutDialog() {
margin-left: -0.2rem;
}
#logo-img {
/* margin-left: -0.2rem; */
}
#logo-list-item {
margin-top: 0.5rem;
margin-bottom: 1.5rem;