mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
Merge pull request #920 from RockChinQ/feat/lifetime-controlling
Feat: 生命周期和热重载
This commit is contained in:
13
main.py
13
main.py
@@ -3,13 +3,14 @@
|
||||
# QChatGPT/main.py
|
||||
|
||||
asciiart = r"""
|
||||
___ ___ _ _ ___ ___ _____
|
||||
/ _ \ / __| |_ __ _| |_ / __| _ \_ _|
|
||||
| (_) | (__| ' \/ _` | _| (_ | _/ | |
|
||||
\__\_\\___|_||_\__,_|\__|\___|_| |_|
|
||||
_ ___ _
|
||||
| | __ _ _ _ __ _| _ ) ___| |_
|
||||
| |__/ _` | ' \/ _` | _ \/ _ \ _|
|
||||
|____\__,_|_||_\__, |___/\___/\__|
|
||||
|___/
|
||||
|
||||
⭐️开源地址: https://github.com/RockChinQ/QChatGPT
|
||||
📖文档地址: https://q.rkcn.top
|
||||
⭐️开源地址: https://github.com/RockChinQ/LangBot
|
||||
📖文档地址: https://docs.langbot.app
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'])
|
||||
async def _() -> str:
|
||||
plugins = self.ap.plugin_mgr.plugins
|
||||
plugins = self.ap.plugin_mgr.plugins()
|
||||
|
||||
plugins_data = [plugin.model_dump() for plugin in plugins]
|
||||
|
||||
@@ -27,7 +27,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
data = await quart.request.json
|
||||
target_enabled = data.get('target_enabled')
|
||||
await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled)
|
||||
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
||||
return self.success()
|
||||
|
||||
@self.route('/<author>/<plugin_name>/update', methods=['POST'])
|
||||
|
||||
@@ -39,3 +39,25 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
return self.http_status(404, 404, "Task not found")
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/reload', methods=['POST'])
|
||||
async def _() -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
scope = json_data.get("scope")
|
||||
|
||||
await self.ap.reload(
|
||||
scope=scope
|
||||
)
|
||||
return self.success()
|
||||
|
||||
@self.route('/_debug/exec', methods=['POST'])
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, "Forbidden")
|
||||
|
||||
py_code = await quart.request.data
|
||||
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {"ap": ap}))
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import quart
|
||||
import quart_cors
|
||||
|
||||
from ....core import app
|
||||
from ....core import app, entities as core_entities
|
||||
from .groups import logs, system, settings, plugins, stats
|
||||
from . import group
|
||||
|
||||
@@ -32,15 +32,26 @@ class HTTPController:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def exception_handler(*args, **kwargs):
|
||||
try:
|
||||
await self.quart_app.run_task(
|
||||
*args, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
|
||||
|
||||
self.ap.task_mgr.create_task(
|
||||
self.quart_app.run_task(
|
||||
exception_handler(
|
||||
host=self.ap.system_cfg.data["http-api"]["host"],
|
||||
port=self.ap.system_cfg.data["http-api"]["port"],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
),
|
||||
name="http-api-quart",
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
# await asyncio.sleep(5)
|
||||
|
||||
async def register_routes(self) -> None:
|
||||
|
||||
@self.quart_app.route("/healthz")
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
|
||||
|
||||
class APIGroup(metaclass=abc.ABCMeta):
|
||||
@@ -65,14 +65,12 @@ class APIGroup(metaclass=abc.ABCMeta):
|
||||
**kwargs,
|
||||
) -> asyncio.Task:
|
||||
"""执行请求"""
|
||||
# task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
|
||||
|
||||
# self.ap.asyncio_tasks.append(task)
|
||||
|
||||
return self.ap.task_mgr.create_task(
|
||||
self._do(method, path, data, params, headers, **kwargs),
|
||||
kind="telemetry-operation",
|
||||
name=f"{method} {path}",
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
).task
|
||||
|
||||
def gen_rid(self):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .. import operator, entities, cmdmgr
|
||||
from ...plugin import context as plugin_context
|
||||
|
||||
|
||||
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
|
||||
@@ -9,16 +10,18 @@ class FuncOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = "当前已加载的内容函数: \n\n"
|
||||
reply_str = "当前已启用的内容函数: \n\n"
|
||||
|
||||
index = 1
|
||||
|
||||
all_functions = await self.ap.tool_mgr.get_all_functions()
|
||||
all_functions = await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
|
||||
)
|
||||
|
||||
for func in all_functions:
|
||||
reply_str += "{}. {}{}:\n{}\n\n".format(
|
||||
reply_str += "{}. {}:\n{}\n\n".format(
|
||||
index,
|
||||
("(已禁用) " if not func.enable else ""),
|
||||
func.name,
|
||||
func.description,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ class PluginOperator(operator.CommandOperator):
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
plugin_list = self.ap.plugin_mgr.plugins
|
||||
plugin_list = self.ap.plugin_mgr.plugins()
|
||||
reply_str = "所有插件({}):\n".format(len(plugin_list))
|
||||
idx = 0
|
||||
for plugin in plugin_list:
|
||||
@@ -110,7 +110,7 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
try:
|
||||
plugins = [
|
||||
p.plugin_name
|
||||
for p in self.ap.plugin_mgr.plugins
|
||||
for p in self.ap.plugin_mgr.plugins()
|
||||
]
|
||||
|
||||
if plugins:
|
||||
@@ -182,7 +182,7 @@ class PluginEnableOperator(operator.CommandOperator):
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_status(plugin_name, True):
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
||||
@@ -210,7 +210,7 @@ class PluginDisableOperator(operator.CommandOperator):
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_status(plugin_name, False):
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
import asyncio
|
||||
import threading
|
||||
import traceback
|
||||
import enum
|
||||
import sys
|
||||
|
||||
from ..platform import manager as im_mgr
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
@@ -21,8 +23,9 @@ from ..pipeline import controller, stagemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
from ..utils import logcache
|
||||
from ..utils import logcache, ip
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
|
||||
|
||||
class Application:
|
||||
@@ -104,24 +107,84 @@ class Application:
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
|
||||
try:
|
||||
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
# 后续可能会允许动态重启其他任务
|
||||
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
||||
async def never_ending():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager")
|
||||
self.task_mgr.create_task(self.ctrl.run(), name="query-controller")
|
||||
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller")
|
||||
self.task_mgr.create_task(never_ending(), name="never-ending-task")
|
||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
|
||||
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
||||
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
||||
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
||||
|
||||
await self.print_web_access_info()
|
||||
await self.task_mgr.wait_all()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.logger.error(f"应用运行致命异常: {e}")
|
||||
self.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def print_web_access_info(self):
|
||||
"""打印访问 webui 的提示"""
|
||||
import socket
|
||||
|
||||
host_ip = socket.gethostbyname(socket.gethostname())
|
||||
|
||||
public_ip = await ip.get_myip()
|
||||
|
||||
port = self.system_cfg.data['http-api']['port']
|
||||
|
||||
tips = f"""
|
||||
=======================================
|
||||
✨ 您可通过以下方式访问管理面板
|
||||
|
||||
🏠 本地地址:http://{host_ip}:{port}/
|
||||
🌐 公网地址:http://{public_ip}:{port}/
|
||||
|
||||
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
|
||||
🔗 若要使用公网地址访问,请阅读以下须知
|
||||
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
|
||||
2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口;
|
||||
|
||||
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
|
||||
=======================================
|
||||
""".strip()
|
||||
for line in tips.split("\n"):
|
||||
self.logger.info(line)
|
||||
|
||||
async def reload(
|
||||
self,
|
||||
scope: core_entities.LifecycleControlScope,
|
||||
):
|
||||
match scope:
|
||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
||||
self.logger.info("执行热重载 scope="+scope)
|
||||
await self.platform_mgr.shutdown()
|
||||
|
||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
||||
|
||||
await self.platform_mgr.initialize()
|
||||
|
||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
|
||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
||||
self.logger.info("执行热重载 scope="+scope)
|
||||
await self.plugin_mgr.destroy_plugins()
|
||||
|
||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
||||
for mod in list(sys.modules.keys()):
|
||||
if mod.startswith("plugins."):
|
||||
del sys.modules[mod]
|
||||
|
||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
||||
await self.plugin_mgr.initialize()
|
||||
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
|
||||
await self.plugin_mgr.load_plugins()
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -53,13 +53,17 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
# 挂系统信号处理
|
||||
import signal
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
print("[Signal] 程序退出.")
|
||||
# ap.shutdown()
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
app_inst = await make_app(loop)
|
||||
ap = app_inst
|
||||
await app_inst.run()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -6,7 +6,7 @@ required_deps = {
|
||||
"anthropic": "anthropic",
|
||||
"colorlog": "colorlog",
|
||||
"aiocqhttp": "aiocqhttp",
|
||||
"botpy": "qq-botpy",
|
||||
"botpy": "qq-botpy-rc",
|
||||
"PIL": "pillow",
|
||||
"nakuru": "nakuru-project-idk",
|
||||
"tiktoken": "tiktoken",
|
||||
|
||||
@@ -17,6 +17,14 @@ from ..platform.types import events as platform_events
|
||||
from ..platform.types import entities as platform_entities
|
||||
|
||||
|
||||
|
||||
class LifecycleControlScope(enum.Enum):
|
||||
|
||||
APPLICATION = "application"
|
||||
PLATFORM = "platform"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
"""一个请求的发起者类型"""
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import datetime
|
||||
import traceback
|
||||
|
||||
from . import app
|
||||
from . import entities as core_entities
|
||||
|
||||
|
||||
class TaskContext:
|
||||
@@ -71,7 +72,7 @@ class TaskWrapper:
|
||||
task_type: str = "system" # 任务类型: system 或 user
|
||||
"""任务类型"""
|
||||
|
||||
kind: str = "system_task"
|
||||
kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同
|
||||
"""任务种类"""
|
||||
|
||||
name: str = ""
|
||||
@@ -92,6 +93,9 @@ class TaskWrapper:
|
||||
ap: app.Application
|
||||
"""应用实例"""
|
||||
|
||||
scopes: list[core_entities.LifecycleControlScope]
|
||||
"""任务所属生命周期控制范围"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
@@ -101,6 +105,7 @@ class TaskWrapper:
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
):
|
||||
self.id = TaskWrapper._id_index
|
||||
TaskWrapper._id_index += 1
|
||||
@@ -112,6 +117,7 @@ class TaskWrapper:
|
||||
self.name = name
|
||||
self.label = label if label != "" else name
|
||||
self.task.set_name(name)
|
||||
self.scopes = scopes
|
||||
|
||||
def assume_exception(self):
|
||||
try:
|
||||
@@ -145,6 +151,7 @@ class TaskWrapper:
|
||||
"kind": self.kind,
|
||||
"name": self.name,
|
||||
"label": self.label,
|
||||
"scopes": [scope.value for scope in self.scopes],
|
||||
"task_context": self.task_context.to_dict(),
|
||||
"runtime": {
|
||||
"done": self.task.done(),
|
||||
@@ -154,6 +161,9 @@ class TaskWrapper:
|
||||
"result": self.assume_result().__str__() if self.assume_result() is not None else None,
|
||||
},
|
||||
}
|
||||
|
||||
def cancel(self):
|
||||
self.task.cancel()
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
@@ -177,8 +187,9 @@ class AsyncTaskManager:
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context)
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
|
||||
self.tasks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
@@ -189,8 +200,9 @@ class AsyncTaskManager:
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
) -> TaskWrapper:
|
||||
return self.create_task(coro, "user", kind, name, label, context)
|
||||
return self.create_task(coro, "user", kind, name, label, context, scopes)
|
||||
|
||||
async def wait_all(self):
|
||||
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
|
||||
@@ -214,3 +226,10 @@ class AsyncTaskManager:
|
||||
if t.id == id:
|
||||
return t
|
||||
return None
|
||||
|
||||
def cancel_by_scope(self, scope: core_entities.LifecycleControlScope):
|
||||
for wrapper in self.tasks:
|
||||
|
||||
if not wrapper.task.done() and scope in wrapper.scopes:
|
||||
|
||||
wrapper.task.cancel()
|
||||
|
||||
@@ -4,7 +4,6 @@ import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
@@ -59,13 +58,11 @@ class Controller:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
# task = asyncio.create_task(_process_query(selected_query))
|
||||
# self.ap.asyncio_tasks.append(task)
|
||||
self.ap.task_mgr.create_task(
|
||||
_process_query(selected_query),
|
||||
kind="query",
|
||||
name=f"query-{selected_query.query_id}",
|
||||
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -166,6 +163,23 @@ class Controller:
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
try:
|
||||
@@ -173,7 +187,6 @@ class Controller:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# traceback.print_exc()
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
|
||||
@@ -37,76 +37,40 @@ class PlatformManager:
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
|
||||
from .sources import nakuru, aiocqhttp, qqbotpy
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.GroupMessageReceived(
|
||||
launcher_type='group',
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
)
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
if not event_ctx.is_prevented_default():
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
index = 0
|
||||
|
||||
@@ -174,24 +138,30 @@ class PlatformManager:
|
||||
try:
|
||||
tasks = []
|
||||
for adapter in self.adapters:
|
||||
async def exception_wrapper(adapter):
|
||||
async def exception_wrapper(adapter: msadapter.MessageSourceAdapter):
|
||||
try:
|
||||
await adapter.run_async()
|
||||
except Exception as e:
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
return
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
tasks.append(exception_wrapper(adapter))
|
||||
|
||||
for task in tasks:
|
||||
# async_task = asyncio.create_task(task)
|
||||
# self.ap.asyncio_tasks.append(async_task)
|
||||
self.ap.task_mgr.create_task(
|
||||
task,
|
||||
kind="platform-adapter",
|
||||
name=f"platform-adapter-{adapter.name}",
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def shutdown(self):
|
||||
for adapter in self.adapters:
|
||||
await adapter.kill()
|
||||
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
|
||||
@@ -328,5 +328,5 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def kill(self) -> bool:
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
@@ -21,7 +21,6 @@ from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
|
||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||
pass
|
||||
|
||||
@@ -588,8 +587,12 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
self.member_openid_mapping, self.group_openid_mapping
|
||||
)
|
||||
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await self.bot.start(**self.cfg)
|
||||
self.cfg['ret_coro'] = True
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await (await self.bot.start(**self.cfg))
|
||||
|
||||
async def kill(self) -> bool:
|
||||
if not self.bot.is_closed():
|
||||
await self.bot.close()
|
||||
return True
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
# import asyncio
|
||||
# import typing
|
||||
|
||||
|
||||
# from .. import adapter as adapter_model
|
||||
# from ...core import app
|
||||
|
||||
|
||||
# @adapter_model.adapter_class("yiri-mirai")
|
||||
# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
# """YiriMirai适配器"""
|
||||
# bot: mirai.Mirai
|
||||
|
||||
# def __init__(self, config: dict, ap: app.Application):
|
||||
# """初始化YiriMirai的对象"""
|
||||
# self.ap = ap
|
||||
# self.config = config
|
||||
# if 'adapter' not in config or \
|
||||
# config['adapter'] == 'WebSocketAdapter':
|
||||
# self.bot = mirai.Mirai(
|
||||
# qq=config['qq'],
|
||||
# adapter=mirai.WebSocketAdapter(
|
||||
# host=config['host'],
|
||||
# port=config['port'],
|
||||
# verify_key=config['verifyKey']
|
||||
# )
|
||||
# )
|
||||
# elif config['adapter'] == 'HTTPAdapter':
|
||||
# self.bot = mirai.Mirai(
|
||||
# qq=config['qq'],
|
||||
# adapter=mirai.HTTPAdapter(
|
||||
# host=config['host'],
|
||||
# port=config['port'],
|
||||
# verify_key=config['verifyKey']
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
|
||||
|
||||
# async def send_message(
|
||||
# self,
|
||||
# target_type: str,
|
||||
# target_id: str,
|
||||
# message: mirai.MessageChain
|
||||
# ):
|
||||
# """发送消息
|
||||
|
||||
# Args:
|
||||
# target_type (str): 目标类型,`person`或`group`
|
||||
# target_id (str): 目标ID
|
||||
# message (mirai.MessageChain): YiriMirai库的消息链
|
||||
# """
|
||||
# task = None
|
||||
# if target_type == 'person':
|
||||
# task = self.bot.send_friend_message(int(target_id), message)
|
||||
# elif target_type == 'group':
|
||||
# task = self.bot.send_group_message(int(target_id), message)
|
||||
# else:
|
||||
# raise Exception('Unknown target type: ' + target_type)
|
||||
|
||||
# await task
|
||||
|
||||
# async def reply_message(
|
||||
# self,
|
||||
# message_source: mirai.MessageEvent,
|
||||
# message: mirai.MessageChain,
|
||||
# quote_origin: bool = False
|
||||
# ):
|
||||
# """回复消息
|
||||
|
||||
# Args:
|
||||
# message_source (mirai.MessageEvent): YiriMirai消息源事件
|
||||
# message (mirai.MessageChain): YiriMirai库的消息链
|
||||
# quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
# """
|
||||
# await self.bot.send(message_source, message, quote_origin)
|
||||
|
||||
# async def is_muted(self, group_id: int) -> bool:
|
||||
# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
|
||||
# if result.mute_time_remaining > 0:
|
||||
# return True
|
||||
# return False
|
||||
|
||||
# def register_listener(
|
||||
# self,
|
||||
# event_type: typing.Type[mirai.Event],
|
||||
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
# ):
|
||||
# """注册事件监听器
|
||||
|
||||
# Args:
|
||||
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
# """
|
||||
# async def wrapper(event: mirai.Event):
|
||||
# await callback(event, self)
|
||||
# self.bot.on(event_type)(wrapper)
|
||||
|
||||
# def unregister_listener(
|
||||
# self,
|
||||
# event_type: typing.Type[mirai.Event],
|
||||
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
# ):
|
||||
# """注销事件监听器
|
||||
|
||||
# Args:
|
||||
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
# """
|
||||
# assert isinstance(self.bot, mirai.Mirai)
|
||||
# bus = self.bot.bus
|
||||
# assert isinstance(bus, mirai.models.bus.ModelEventBus)
|
||||
|
||||
# bus.unsubscribe(event_type, callback)
|
||||
|
||||
# async def run_async(self):
|
||||
# self.bot_account_id = self.bot.qq
|
||||
# return await MiraiRunner(self.bot)._run()
|
||||
|
||||
# async def kill(self) -> bool:
|
||||
# return False
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import abc
|
||||
import pydantic
|
||||
import enum
|
||||
|
||||
from . import events
|
||||
from ..provider.tools import entities as tools_entities
|
||||
@@ -85,10 +86,19 @@ class BasePlugin(metaclass=abc.ABCMeta):
|
||||
"""应用程序对象"""
|
||||
|
||||
def __init__(self, host: APIHost):
|
||||
"""初始化阶段被调用"""
|
||||
self.host = host
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化插件"""
|
||||
"""初始化阶段被调用"""
|
||||
pass
|
||||
|
||||
async def destroy(self):
|
||||
"""释放/禁用插件时被调用"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""释放/禁用插件时被调用"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -247,6 +257,16 @@ class EventContext:
|
||||
EventContext.eid += 1
|
||||
|
||||
|
||||
class RuntimeContainerStatus(enum.Enum):
|
||||
"""插件容器状态"""
|
||||
|
||||
MOUNTED = "mounted"
|
||||
"""已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态"""
|
||||
|
||||
INITIALIZED = "initialized"
|
||||
"""已初始化"""
|
||||
|
||||
|
||||
class RuntimeContainer(pydantic.BaseModel):
|
||||
"""运行时的插件容器
|
||||
|
||||
@@ -294,6 +314,9 @@ class RuntimeContainer(pydantic.BaseModel):
|
||||
content_functions: list[tools_entities.LLMFunction] = []
|
||||
"""内容函数"""
|
||||
|
||||
status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED
|
||||
"""插件状态"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -318,9 +341,6 @@ class RuntimeContainer(pydantic.BaseModel):
|
||||
self.priority = setting['priority']
|
||||
self.enabled = setting['enabled']
|
||||
|
||||
for function in self.content_functions:
|
||||
function.enable = self.enabled
|
||||
|
||||
def model_dump(self, *args, **kwargs):
|
||||
return {
|
||||
'name': self.plugin_name,
|
||||
@@ -342,9 +362,9 @@ class RuntimeContainer(pydantic.BaseModel):
|
||||
'human_desc': function.human_desc,
|
||||
'description': function.description,
|
||||
'parameters': function.parameters,
|
||||
'enable': function.enable,
|
||||
'func': function.func.__name__,
|
||||
}
|
||||
for function in self.content_functions
|
||||
],
|
||||
'status': self.status.value,
|
||||
}
|
||||
|
||||
@@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
plugins: list[context.RuntimeContainer]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.plugins = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load_plugins(self) -> list[context.RuntimeContainer]:
|
||||
async def load_plugins(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
_current_container: context.RuntimeContainer = None
|
||||
|
||||
containers: list[context.RuntimeContainer] = []
|
||||
plugins: list[context.RuntimeContainer] = []
|
||||
|
||||
def __init__(self, ap):
|
||||
self.ap = ap
|
||||
self.plugins = []
|
||||
self._current_pkg_path = ''
|
||||
self._current_module_path = ''
|
||||
self._current_container = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化"""
|
||||
@@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader):
|
||||
}
|
||||
|
||||
# 把 ctx.event 所有的属性都放到 args 里
|
||||
for k, v in ctx.event.dict().items():
|
||||
args[k] = v
|
||||
# for k, v in ctx.event.dict().items():
|
||||
# args[k] = v
|
||||
for attr_name in ctx.event.__dict__.keys():
|
||||
args[attr_name] = getattr(ctx.event, attr_name)
|
||||
|
||||
func(plugin, **args)
|
||||
|
||||
@@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader):
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=handler,
|
||||
)
|
||||
@@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader):
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=func,
|
||||
)
|
||||
@@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader):
|
||||
importlib.import_module(module.__name__ + "." + item.name)
|
||||
|
||||
if self._current_container is not None:
|
||||
self.containers.append(self._current_container)
|
||||
self.plugins.append(self._current_container)
|
||||
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
|
||||
except:
|
||||
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
|
||||
traceback.print_exc()
|
||||
|
||||
async def load_plugins(self) -> list[context.RuntimeContainer]:
|
||||
async def load_plugins(self):
|
||||
"""加载插件
|
||||
"""
|
||||
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
|
||||
|
||||
return self.containers
|
||||
|
||||
@@ -22,7 +22,22 @@ class PluginManager:
|
||||
|
||||
api_host: context.APIHost
|
||||
|
||||
plugins: list[context.RuntimeContainer]
|
||||
def plugins(
|
||||
self,
|
||||
enabled: bool=None,
|
||||
status: context.RuntimeContainerStatus=None,
|
||||
) -> list[context.RuntimeContainer]:
|
||||
"""获取插件列表
|
||||
"""
|
||||
plugins = self.loader.plugins
|
||||
|
||||
if enabled is not None:
|
||||
plugins = [plugin for plugin in plugins if plugin.enabled == enabled]
|
||||
|
||||
if status is not None:
|
||||
plugins = [plugin for plugin in plugins if plugin.status == status]
|
||||
|
||||
return plugins
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
@@ -30,7 +45,6 @@ class PluginManager:
|
||||
self.installer = github.GitHubRepoInstaller(ap)
|
||||
self.setting = setting.SettingManager(ap)
|
||||
self.api_host = context.APIHost(ap)
|
||||
self.plugins = []
|
||||
|
||||
async def initialize(self):
|
||||
await self.loader.initialize()
|
||||
@@ -41,27 +55,58 @@ class PluginManager:
|
||||
setattr(models, 'require_ver', self.api_host.require_ver)
|
||||
|
||||
async def load_plugins(self):
|
||||
self.plugins = await self.loader.load_plugins()
|
||||
await self.loader.load_plugins()
|
||||
|
||||
await self.setting.sync_setting(self.plugins)
|
||||
await self.setting.sync_setting(self.loader.plugins)
|
||||
|
||||
# 按优先级倒序
|
||||
self.plugins.sort(key=lambda x: x.priority, reverse=True)
|
||||
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}')
|
||||
self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}')
|
||||
|
||||
async def initialize_plugin(self, plugin: context.RuntimeContainer):
|
||||
self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}')
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
plugin.status = context.RuntimeContainerStatus.INITIALIZED
|
||||
|
||||
async def initialize_plugins(self):
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins():
|
||||
if not plugin.enabled:
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化')
|
||||
continue
|
||||
try:
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
await self.initialize_plugin(plugin)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
continue
|
||||
|
||||
async def destroy_plugin(self, plugin: context.RuntimeContainer):
|
||||
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f'释放插件 {plugin.plugin_name}')
|
||||
plugin.plugin_inst.__del__()
|
||||
await plugin.plugin_inst.destroy()
|
||||
plugin.plugin_inst = None
|
||||
plugin.status = context.RuntimeContainerStatus.MOUNTED
|
||||
|
||||
async def destroy_plugins(self):
|
||||
for plugin in self.plugins():
|
||||
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放')
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.destroy_plugin(plugin)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
continue
|
||||
|
||||
async def install_plugin(
|
||||
self,
|
||||
plugin_source: str,
|
||||
@@ -80,6 +125,9 @@ class PluginManager:
|
||||
}
|
||||
)
|
||||
|
||||
task_context.trace('重载插件..', 'reload-plugin')
|
||||
await self.ap.reload(scope='plugin')
|
||||
|
||||
async def uninstall_plugin(
|
||||
self,
|
||||
plugin_name: str,
|
||||
@@ -87,10 +135,15 @@ class PluginManager:
|
||||
):
|
||||
"""卸载插件
|
||||
"""
|
||||
await self.installer.uninstall_plugin(plugin_name, task_context)
|
||||
|
||||
plugin_container = self.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is None:
|
||||
raise ValueError(f'插件 {plugin_name} 不存在')
|
||||
|
||||
await self.destroy_plugin(plugin_container)
|
||||
await self.installer.uninstall_plugin(plugin_name, task_context)
|
||||
|
||||
await self.ap.ctr_mgr.plugin.post_remove_record(
|
||||
{
|
||||
"name": plugin_name,
|
||||
@@ -100,6 +153,9 @@ class PluginManager:
|
||||
}
|
||||
)
|
||||
|
||||
task_context.trace('重载插件..', 'reload-plugin')
|
||||
await self.ap.reload(scope='plugin')
|
||||
|
||||
async def update_plugin(
|
||||
self,
|
||||
plugin_name: str,
|
||||
@@ -123,11 +179,13 @@ class PluginManager:
|
||||
new_version="HEAD"
|
||||
)
|
||||
|
||||
task_context.trace('重载插件..', 'reload-plugin')
|
||||
await self.ap.reload(scope='plugin')
|
||||
|
||||
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
|
||||
"""通过插件名获取插件
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins():
|
||||
if plugin.plugin_name == plugin_name:
|
||||
return plugin
|
||||
return None
|
||||
@@ -143,30 +201,32 @@ class PluginManager:
|
||||
|
||||
emitted_plugins: list[context.RuntimeContainer] = []
|
||||
|
||||
for plugin in self.plugins:
|
||||
if plugin.enabled:
|
||||
if event.__class__ in plugin.event_handlers:
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
|
||||
|
||||
is_prevented_default_before_call = ctx.is_prevented_default()
|
||||
for plugin in self.plugins(
|
||||
enabled=True,
|
||||
status=context.RuntimeContainerStatus.INITIALIZED
|
||||
):
|
||||
if event.__class__ in plugin.event_handlers:
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
|
||||
|
||||
is_prevented_default_before_call = ctx.is_prevented_default()
|
||||
|
||||
try:
|
||||
await plugin.event_handlers[event.__class__](
|
||||
plugin.plugin_inst,
|
||||
ctx
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
try:
|
||||
await plugin.event_handlers[event.__class__](
|
||||
plugin.plugin_inst,
|
||||
ctx
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
|
||||
if not is_prevented_default_before_call and ctx.is_prevented_default():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
|
||||
if not is_prevented_default_before_call and ctx.is_prevented_default():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
|
||||
|
||||
if ctx.is_prevented_postorder():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
|
||||
break
|
||||
if ctx.is_prevented_postorder():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
|
||||
break
|
||||
|
||||
for key in ctx.__return_value__.keys():
|
||||
if hasattr(ctx.event, key):
|
||||
@@ -191,16 +251,22 @@ class PluginManager:
|
||||
|
||||
return ctx
|
||||
|
||||
async def update_plugin_status(self, plugin_name: str, new_status: bool):
|
||||
async def update_plugin_switch(self, plugin_name: str, new_status: bool):
|
||||
if self.get_plugin_by_name(plugin_name) is not None:
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins():
|
||||
if plugin.plugin_name == plugin_name:
|
||||
plugin.enabled = new_status
|
||||
|
||||
for func in plugin.content_functions:
|
||||
func.enable = new_status
|
||||
if plugin.enabled == new_status:
|
||||
return False
|
||||
|
||||
await self.setting.dump_container_setting(self.plugins)
|
||||
# 初始化/释放插件
|
||||
if new_status:
|
||||
await self.initialize_plugin(plugin)
|
||||
else:
|
||||
await self.destroy_plugin(plugin)
|
||||
|
||||
plugin.enabled = new_status
|
||||
|
||||
await self.setting.dump_container_setting(self.loader.plugins)
|
||||
|
||||
break
|
||||
|
||||
@@ -214,11 +280,11 @@ class PluginManager:
|
||||
plugin_name = plugin.get('name')
|
||||
plugin_priority = plugin.get('priority')
|
||||
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.loader.plugins:
|
||||
if plugin.plugin_name == plugin_name:
|
||||
plugin.priority = plugin_priority
|
||||
break
|
||||
|
||||
self.plugins.sort(key=lambda x: x.priority, reverse=True)
|
||||
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
await self.setting.dump_container_setting(self.plugins)
|
||||
await self.setting.dump_container_setting(self.loader.plugins)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...plugin import context as plugin_context
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@@ -51,7 +52,10 @@ class SessionManager:
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
|
||||
),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
session.using_conversation = conversation
|
||||
|
||||
@@ -20,8 +20,6 @@ class LLMFunction(pydantic.BaseModel):
|
||||
description: str
|
||||
"""给LLM识别的函数描述"""
|
||||
|
||||
enable: typing.Optional[bool] = True
|
||||
|
||||
parameters: dict
|
||||
|
||||
func: typing.Callable
|
||||
|
||||
@@ -20,28 +20,25 @@ class ToolManager:
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def get_function(self, name: str) -> entities.LLMFunction:
|
||||
"""获取函数"""
|
||||
for function in await self.get_all_functions():
|
||||
if function.name == name:
|
||||
return function
|
||||
return None
|
||||
|
||||
async def get_function_and_plugin(
|
||||
self, name: str
|
||||
) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
|
||||
"""获取函数和插件"""
|
||||
for plugin in self.ap.plugin_mgr.plugins:
|
||||
"""获取函数和插件实例"""
|
||||
for plugin in self.ap.plugin_mgr.plugins(
|
||||
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED
|
||||
):
|
||||
for function in plugin.content_functions:
|
||||
if function.name == name:
|
||||
return function, plugin.plugin_inst
|
||||
return None, None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
async def get_all_functions(self, plugin_enabled: bool=None, plugin_status: plugin_context.RuntimeContainerStatus=None) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[entities.LLMFunction] = []
|
||||
|
||||
for plugin in self.ap.plugin_mgr.plugins:
|
||||
for plugin in self.ap.plugin_mgr.plugins(
|
||||
enabled=plugin_enabled, status=plugin_status
|
||||
):
|
||||
all_functions.extend(plugin.content_functions)
|
||||
|
||||
return all_functions
|
||||
@@ -51,16 +48,15 @@ class ToolManager:
|
||||
tools = []
|
||||
|
||||
for function in use_funcs:
|
||||
if function.enable:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters,
|
||||
},
|
||||
}
|
||||
tools.append(function_schema)
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters,
|
||||
},
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
return tools
|
||||
|
||||
@@ -92,13 +88,12 @@ class ToolManager:
|
||||
tools = []
|
||||
|
||||
for function in use_funcs:
|
||||
if function.enable:
|
||||
function_schema = {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"input_schema": function.parameters,
|
||||
}
|
||||
tools.append(function_schema)
|
||||
function_schema = {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"input_schema": function.parameters,
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
9
pkg/utils/ip.py
Normal file
9
pkg/utils/ip.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import aiohttp
|
||||
|
||||
async def get_myip() -> str:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get("https://ip.useragentinfo.com/myip") as response:
|
||||
return await response.text()
|
||||
except Exception as e:
|
||||
return '0.0.0.0'
|
||||
@@ -3,7 +3,7 @@ openai>1.0.0
|
||||
anthropic
|
||||
colorlog~=6.6.0
|
||||
aiocqhttp
|
||||
qq-botpy
|
||||
qq-botpy-rc
|
||||
nakuru-project-idk
|
||||
Pillow
|
||||
tiktoken
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
{
|
||||
"platform-adapters": [
|
||||
{
|
||||
"adapter": "yiri-mirai",
|
||||
"enable": false,
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"verifyKey": "yirimirai",
|
||||
"qq": 123456789
|
||||
},
|
||||
{
|
||||
"adapter": "nakuru",
|
||||
"enable": false,
|
||||
|
||||
@@ -9,43 +9,6 @@
|
||||
"items": {
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
{
|
||||
"title": "YiriMirai 适配器",
|
||||
"description": "用于接入 Mirai",
|
||||
"properties": {
|
||||
"adapter": {
|
||||
"type": "string",
|
||||
"const": "yiri-mirai"
|
||||
},
|
||||
"enable": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "是否启用此适配器",
|
||||
"layout": {
|
||||
"comp": "switch",
|
||||
"props": {
|
||||
"color": "primary"
|
||||
}
|
||||
}
|
||||
},
|
||||
"host": {
|
||||
"type": "string",
|
||||
"default": "127.0.0.1"
|
||||
},
|
||||
"port": {
|
||||
"type": "integer",
|
||||
"default": 8080
|
||||
},
|
||||
"verifyKey": {
|
||||
"type": "string",
|
||||
"default": "yirimirai"
|
||||
},
|
||||
"qq": {
|
||||
"type": "integer",
|
||||
"default": 123456789
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"title": "Nakuru 适配器",
|
||||
"description": "用于接入 go-cqhttp",
|
||||
|
||||
@@ -67,6 +67,17 @@
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="reload('platform')">
|
||||
<v-list-item-title>
|
||||
重载消息平台
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="reload('plugin')">
|
||||
<v-list-item-title>
|
||||
重载插件
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</v-list-item>
|
||||
@@ -137,6 +148,30 @@ function openDocs() {
|
||||
window.open('https://docs.langbot.app', '_blank')
|
||||
}
|
||||
|
||||
const reloadScopeLabel = {
|
||||
'platform': "消息平台",
|
||||
'plugin': "插件"
|
||||
}
|
||||
|
||||
function reload(scope) {
|
||||
let label = reloadScopeLabel[scope]
|
||||
proxy.$axios.post('/system/reload',
|
||||
{ scope: scope },
|
||||
{ headers: { 'Content-Type': 'application/json' } }
|
||||
).then(response => {
|
||||
if (response.data.code === 0) {
|
||||
success(label+'已重载')
|
||||
|
||||
// 关闭菜单
|
||||
} else {
|
||||
error(label+'重载失败:' + response.data.message)
|
||||
}
|
||||
}).catch(err => {
|
||||
error(label+'重载失败:' + err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
const aboutDialogShow = ref(false)
|
||||
|
||||
function showAboutDialog() {
|
||||
@@ -162,10 +197,6 @@ function closeAboutDialog() {
|
||||
margin-left: -0.2rem;
|
||||
}
|
||||
|
||||
#logo-img {
|
||||
/* margin-left: -0.2rem; */
|
||||
}
|
||||
|
||||
#logo-list-item {
|
||||
margin-top: 0.5rem;
|
||||
margin-bottom: 1.5rem;
|
||||
|
||||
Reference in New Issue
Block a user