feat: 完成异步任务跟踪架构基础

This commit is contained in:
Junyan Qin
2024-11-01 22:41:26 +08:00
parent 2f05f5b456
commit 6d2a4c038d
16 changed files with 395 additions and 101 deletions
+16 -2
View File
@@ -5,7 +5,7 @@ import traceback
import quart
from .....core import app
from .....core import app, taskmgr
from .. import group
@@ -23,12 +23,26 @@ class PluginsRouterGroup(group.RouterGroup):
'plugins': plugins_data
})
@self.route('/toggle/<author>/<plugin_name>', methods=['PUT'])
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'])
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled)
return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'])
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f"plugin-update-{plugin_name}",
label=f"更新插件 {plugin_name}",
context=ctx
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/reorder', methods=['PUT'])
async def _() -> str:
+22 -1
View File
@@ -1,6 +1,7 @@
import quart
import asyncio
from .....core import app
from .....core import app, taskmgr
from .. import group
from .....utils import constants
@@ -17,3 +18,23 @@ class SystemRouterGroup(group.RouterGroup):
"debug": constants.debug_mode
}
)
@self.route('/tasks', methods=['GET'])
async def _() -> str:
task_type = quart.request.args.get("type")
if task_type == '':
task_type = None
return self.success(
data=self.ap.task_mgr.get_tasks_dict(task_type)
)
@self.route('/tasks/<task_id>', methods=['GET'])
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
if task is None:
return self.http_status(404, 404, "Task not found")
return self.success(data=task.to_dict())
+15 -21
View File
@@ -19,39 +19,33 @@ class HTTPController:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.quart_app = quart.Quart(__name__)
quart_cors.cors(self.quart_app, allow_origin='*')
quart_cors.cors(self.quart_app, allow_origin="*")
async def initialize(self) -> None:
await self.register_routes()
async def run(self) -> None:
if self.ap.system_cfg.data['http-api']['enable']:
if self.ap.system_cfg.data["http-api"]["enable"]:
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
# task = asyncio.create_task(self.quart_app.run_task(
# host=self.ap.system_cfg.data['http-api']['host'],
# port=self.ap.system_cfg.data['http-api']['port'],
# shutdown_trigger=shutdown_trigger_placeholder
# ))
# self.ap.asyncio_tasks.append(task)
self.ap.task_mgr.create_task(self.quart_app.run_task(
host=self.ap.system_cfg.data['http-api']['host'],
port=self.ap.system_cfg.data['http-api']['port'],
shutdown_trigger=shutdown_trigger_placeholder
))
self.ap.task_mgr.create_task(
self.quart_app.run_task(
host=self.ap.system_cfg.data["http-api"]["host"],
port=self.ap.system_cfg.data["http-api"]["port"],
shutdown_trigger=shutdown_trigger_placeholder,
),
name="http-api-quart",
)
async def register_routes(self) -> None:
@self.quart_app.route('/healthz')
@self.quart_app.route("/healthz")
async def healthz():
return {
"code": 0,
"msg": "ok"
}
return {"code": 0, "msg": "ok"}
for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app)
await ginst.initialize()
+19 -26
View File
@@ -14,6 +14,7 @@ from ...core import app
class APIGroup(metaclass=abc.ABCMeta):
"""API 组抽象类"""
_basic_info: dict = None
_runtime_info: dict = None
@@ -32,33 +33,28 @@ class APIGroup(metaclass=abc.ABCMeta):
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
**kwargs,
):
"""
执行请求
"""
self._runtime_info['account_id'] = "-1"
self._runtime_info["account_id"] = "-1"
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
headers["Content-Type"] = "application/json"
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method,
url,
data=data,
params=params,
headers=headers,
**kwargs
method, url, data=data, params=params, headers=headers, **kwargs
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.text())
except Exception as e:
self.ap.logger.debug(f'上报失败: {e}')
self.ap.logger.debug(f"上报失败: {e}")
async def do(
self,
method: str,
@@ -66,32 +62,29 @@ class APIGroup(metaclass=abc.ABCMeta):
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
**kwargs,
) -> asyncio.Task:
"""执行请求"""
# task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
# self.ap.asyncio_tasks.append(task)
return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation",
name=f"{method} {path}",
).task
return self.ap.task_mgr.create_task(self._do(method, path, data, params, headers, **kwargs)).task
def gen_rid(
self
):
def gen_rid(self):
"""生成一个请求 ID"""
return str(uuid.uuid4())
def basic_info(
self
):
def basic_info(self):
"""获取基本信息"""
basic_info = APIGroup._basic_info.copy()
basic_info['rid'] = self.gen_rid()
basic_info["rid"] = self.gen_rid()
return basic_info
def runtime_info(
self
):
def runtime_info(self):
"""获取运行时信息"""
return APIGroup._runtime_info
+4 -11
View File
@@ -114,17 +114,10 @@ class Application:
while True:
await asyncio.sleep(1)
# tasks = [
# asyncio.create_task(self.platform_mgr.run()), # 消息平台
# asyncio.create_task(self.ctrl.run()), # 消息处理循环
# asyncio.create_task(self.http_ctrl.run()), # http 接口服务
# asyncio.create_task(never_ending())
# ]
# self.asyncio_tasks.extend(tasks)
self.task_mgr.create_task(self.platform_mgr.run())
self.task_mgr.create_task(self.ctrl.run())
self.task_mgr.create_task(self.http_ctrl.run())
self.task_mgr.create_task(never_ending())
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager")
self.task_mgr.create_task(self.ctrl.run(), name="query-controller")
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller")
self.task_mgr.create_task(never_ending(), name="never-ending-task")
await self.task_mgr.wait_all()
except asyncio.CancelledError:
+132 -6
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import typing
import datetime
from . import app
@@ -16,22 +17,68 @@ class TaskContext:
"""记录日志"""
def __init__(self):
self.current_action = ""
self.current_action = "default"
self.log = ""
def log(self, msg: str):
def _log(self, msg: str):
self.log += msg + "\n"
def set_current_action(self, action: str):
self.current_action = action
def trace(
self,
msg: str,
action: str = None,
):
if action is not None:
self.set_current_action(action)
self._log(
f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | {self.current_action} | {msg}"
)
def to_dict(self) -> dict:
return {"current_action": self.current_action, "log": self.log}
@staticmethod
def new() -> TaskContext:
return TaskContext()
@staticmethod
def placeholder() -> TaskContext:
global placeholder_context
if placeholder_context is None:
placeholder_context = TaskContext()
return placeholder_context
placeholder_context: TaskContext | None = None
class TaskWrapper:
"""任务包装器"""
_id_index: int = 0
"""任务ID索引"""
id: int
"""任务ID"""
task_type: str = "system" # 任务类型: system 或 user
"""任务类型"""
kind: str = "system_task"
"""任务种类"""
name: str = ""
"""任务唯一名称"""
label: str = ""
"""任务显示名称"""
task_context: TaskContext
"""任务上下文"""
@@ -41,17 +88,61 @@ class TaskWrapper:
ap: app.Application
"""应用实例"""
def __init__(self, ap: app.Application, coro: typing.Coroutine, task_type: str = "system", context: TaskContext = None):
def __init__(
self,
ap: app.Application,
coro: typing.Coroutine,
task_type: str = "system",
kind: str = "system_task",
name: str = "",
label: str = "",
context: TaskContext = None,
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
self.ap = ap
self.task_context = context or TaskContext()
self.task = self.ap.event_loop.create_task(coro)
self.task_type = task_type
self.kind = kind
self.name = name
self.label = label if label != "" else name
self.task.set_name(name)
def assume_exception(self):
try:
return self.task.exception()
except:
return None
def assume_result(self):
try:
return self.task.result()
except:
return None
def to_dict(self) -> dict:
return {
"id": self.id,
"task_type": self.task_type,
"kind": self.kind,
"name": self.name,
"label": self.label,
"task_context": self.task_context.to_dict(),
"runtime": {
"done": self.task.done(),
"state": self.task._state,
"exception": self.assume_exception(),
"result": self.assume_result(),
},
}
class AsyncTaskManager:
"""保存app中的所有异步任务
包含系统级的和用户级(插件安装、更新等由用户直接发起的)的"""
ap: app.Application
tasks: list[TaskWrapper]
@@ -61,13 +152,48 @@ class AsyncTaskManager:
self.ap = ap
self.tasks = []
def create_task(self, coro: typing.Coroutine, task_type: str = "system", context: TaskContext = None) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, context)
def create_task(
self,
coro: typing.Coroutine,
task_type: str = "system",
kind: str = "system-task",
name: str = "",
label: str = "",
context: TaskContext = None,
) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context)
self.tasks.append(wrapper)
return wrapper
def create_user_task(
self,
coro: typing.Coroutine,
kind: str = "user-task",
name: str = "",
label: str = "",
context: TaskContext = None,
) -> TaskWrapper:
return self.create_task(coro, "user", kind, name, label, context)
async def wait_all(self):
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
def get_all_tasks(self) -> list[TaskWrapper]:
return self.tasks
def get_tasks_dict(
self,
type: str = None,
) -> dict:
return {
"tasks": [
t.to_dict() for t in self.tasks if type is None or t.task_type == type
],
"id_index": TaskWrapper._id_index,
}
def get_task_by_id(self, id: int) -> TaskWrapper | None:
for t in self.tasks:
if t.id == id:
return t
return None
+5 -1
View File
@@ -62,7 +62,11 @@ class Controller:
# task = asyncio.create_task(_process_query(selected_query))
# self.ap.asyncio_tasks.append(task)
self.ap.task_mgr.create_task(_process_query(selected_query))
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind="query",
name=f"query-{selected_query.query_id}",
)
except Exception as e:
# traceback.print_exc()
+5 -1
View File
@@ -186,7 +186,11 @@ class PlatformManager:
for task in tasks:
# async_task = asyncio.create_task(task)
# self.ap.asyncio_tasks.append(async_task)
self.ap.task_mgr.create_task(task)
self.ap.task_mgr.create_task(
task,
kind="platform-adapter",
name=f"platform-adapter-{adapter.name}",
)
except Exception as e:
self.ap.logger.error('平台适配器运行出错: ' + str(e))
+2 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
import abc
from ..core import app
from ..core import app, taskmgr
class PluginInstaller(metaclass=abc.ABCMeta):
@@ -40,6 +40,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
self,
plugin_name: str,
plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件
"""
+14 -1
View File
@@ -9,6 +9,7 @@ import requests
from .. import installer, errors
from ...utils import pkgmgr
from ...core import taskmgr
class GitHubRepoInstaller(installer.PluginInstaller):
@@ -94,13 +95,20 @@ class GitHubRepoInstaller(installer.PluginInstaller):
async def install_plugin(
self,
plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""安装插件
"""
task_context.trace("下载插件源码...", "install-plugin")
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/")
task_context.trace("安装插件依赖...", "install-plugin")
await self.install_requirements("plugins/" + repo_label)
task_context.trace("完成.", "install-plugin")
await self.ap.plugin_mgr.setting.record_installed_plugin_source(
"plugins/"+repo_label+'/', plugin_source
)
@@ -122,9 +130,12 @@ class GitHubRepoInstaller(installer.PluginInstaller):
self,
plugin_name: str,
plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件
"""
task_context.trace("更新插件...", "update-plugin")
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is None:
@@ -133,7 +144,9 @@ class GitHubRepoInstaller(installer.PluginInstaller):
if plugin_container.plugin_source:
plugin_source = plugin_container.plugin_source
await self.install_plugin(plugin_source)
task_context.trace("转交安装任务.", "update-plugin")
await self.install_plugin(plugin_source, task_context)
else:
raise errors.PluginInstallerError('插件无源码信息,无法更新')
+4 -2
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
import traceback
from ..core import app
from ..core import app, taskmgr
from . import context, loader, events, installer, setting, models
from .loaders import classic
from .installers import github
@@ -102,10 +102,11 @@ class PluginManager:
self,
plugin_name: str,
plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
):
"""更新插件
"""
await self.installer.update_plugin(plugin_name, plugin_source)
await self.installer.update_plugin(plugin_name, plugin_source, task_context)
plugin_container = self.get_plugin_by_name(plugin_name)
@@ -120,6 +121,7 @@ class PluginManager:
new_version="HEAD"
)
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
"""