feat: 消息平台热重载

This commit is contained in:
Junyan Qin
2024-11-16 12:40:57 +08:00
parent 3239c9ec3f
commit bb219889e5
14 changed files with 168 additions and 36 deletions
+49 -10
View File
@@ -4,6 +4,7 @@ import logging
import asyncio
import threading
import traceback
import enum
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
@@ -21,8 +22,9 @@ from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..utils import logcache
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
class Application:
@@ -114,11 +116,12 @@ class Application:
while True:
await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager")
self.task_mgr.create_task(self.ctrl.run(), name="query-controller")
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller")
self.task_mgr.create_task(never_ending(), name="never-ending-task")
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
await self.print_web_access_info()
await self.task_mgr.wait_all()
except asyncio.CancelledError:
pass
@@ -126,9 +129,45 @@ class Application:
self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}")
async def scoped_shutdown(self, scopes: list[str]):
pass
async def print_web_access_info(self):
"""打印访问 webui 的提示"""
import socket
async def shutdown(self):
for task in self.task_mgr.tasks:
task.cancel()
host_ip = socket.gethostbyname(socket.gethostname())
public_ip = await ip.get_myip()
port = self.system_cfg.data['http-api']['port']
tips = f"""
=======================================
✨ 您可通过以下方式访问管理面板:
🏠 本地地址:http://{host_ip}:{port}/
🌐 公网地址:http://{public_ip}:{port}/
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
🔗 若要使用公网地址访问,请阅读以下须知
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口;
=======================================
""".strip()
for line in tips.split("\n"):
self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
case _:
pass
+1 -1
View File
@@ -6,7 +6,7 @@ required_deps = {
"anthropic": "anthropic",
"colorlog": "colorlog",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy",
"botpy": "qq-botpy-rc",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken",
+8
View File
@@ -17,6 +17,14 @@ from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
+19 -3
View File
@@ -6,6 +6,7 @@ import datetime
import traceback
from . import app
from . import entities as core_entities
class TaskContext:
@@ -71,7 +72,7 @@ class TaskWrapper:
task_type: str = "system" # 任务类型: system 或 user
"""任务类型"""
kind: str = "system_task"
kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同
"""任务种类"""
name: str = ""
@@ -92,6 +93,9 @@ class TaskWrapper:
ap: app.Application
"""应用实例"""
scopes: list[core_entities.LifecycleControlScope]
"""任务所属生命周期控制范围"""
def __init__(
self,
ap: app.Application,
@@ -101,6 +105,7 @@ class TaskWrapper:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
@@ -112,6 +117,7 @@ class TaskWrapper:
self.name = name
self.label = label if label != "" else name
self.task.set_name(name)
self.scopes = scopes
def assume_exception(self):
try:
@@ -145,6 +151,7 @@ class TaskWrapper:
"kind": self.kind,
"name": self.name,
"label": self.label,
"scopes": [scope.value for scope in self.scopes],
"task_context": self.task_context.to_dict(),
"runtime": {
"done": self.task.done(),
@@ -180,8 +187,9 @@ class AsyncTaskManager:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context)
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper)
return wrapper
@@ -192,8 +200,9 @@ class AsyncTaskManager:
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
return self.create_task(coro, "user", kind, name, label, context)
return self.create_task(coro, "user", kind, name, label, context, scopes)
async def wait_all(self):
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
@@ -217,3 +226,10 @@ class AsyncTaskManager:
if t.id == id:
return t
return None
def cancel_by_scope(self, scope: core_entities.LifecycleControlScope):
for wrapper in self.tasks:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()