mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 16:04:21 +00:00
feat: 完成异步任务跟踪架构基础
This commit is contained in:
@@ -5,7 +5,7 @@ import traceback
|
||||
|
||||
import quart
|
||||
|
||||
from .....core import app
|
||||
from .....core import app, taskmgr
|
||||
from .. import group
|
||||
|
||||
|
||||
@@ -23,12 +23,26 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
'plugins': plugins_data
|
||||
})
|
||||
|
||||
@self.route('/toggle/<author>/<plugin_name>', methods=['PUT'])
|
||||
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'])
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
data = await quart.request.json
|
||||
target_enabled = data.get('target_enabled')
|
||||
await self.ap.plugin_mgr.update_plugin_status(plugin_name, target_enabled)
|
||||
return self.success()
|
||||
|
||||
@self.route('/<author>/<plugin_name>/update', methods=['POST'])
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
||||
kind="plugin-operation",
|
||||
name=f"plugin-update-{plugin_name}",
|
||||
label=f"更新插件 {plugin_name}",
|
||||
context=ctx
|
||||
)
|
||||
return self.success(data={
|
||||
'task_id': wrapper.id
|
||||
})
|
||||
|
||||
@self.route('/reorder', methods=['PUT'])
|
||||
async def _() -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import quart
|
||||
import asyncio
|
||||
|
||||
from .....core import app
|
||||
from .....core import app, taskmgr
|
||||
from .. import group
|
||||
from .....utils import constants
|
||||
|
||||
@@ -17,3 +18,23 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
"debug": constants.debug_mode
|
||||
}
|
||||
)
|
||||
|
||||
@self.route('/tasks', methods=['GET'])
|
||||
async def _() -> str:
|
||||
task_type = quart.request.args.get("type")
|
||||
|
||||
if task_type == '':
|
||||
task_type = None
|
||||
|
||||
return self.success(
|
||||
data=self.ap.task_mgr.get_tasks_dict(task_type)
|
||||
)
|
||||
|
||||
@self.route('/tasks/<task_id>', methods=['GET'])
|
||||
async def _(task_id: str) -> str:
|
||||
task = self.ap.task_mgr.get_task_by_id(int(task_id))
|
||||
|
||||
if task is None:
|
||||
return self.http_status(404, 404, "Task not found")
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@@ -19,39 +19,33 @@ class HTTPController:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
self.quart_app = quart.Quart(__name__)
|
||||
quart_cors.cors(self.quart_app, allow_origin='*')
|
||||
quart_cors.cors(self.quart_app, allow_origin="*")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.register_routes()
|
||||
|
||||
async def run(self) -> None:
|
||||
if self.ap.system_cfg.data['http-api']['enable']:
|
||||
if self.ap.system_cfg.data["http-api"]["enable"]:
|
||||
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# task = asyncio.create_task(self.quart_app.run_task(
|
||||
# host=self.ap.system_cfg.data['http-api']['host'],
|
||||
# port=self.ap.system_cfg.data['http-api']['port'],
|
||||
# shutdown_trigger=shutdown_trigger_placeholder
|
||||
# ))
|
||||
# self.ap.asyncio_tasks.append(task)
|
||||
self.ap.task_mgr.create_task(self.quart_app.run_task(
|
||||
host=self.ap.system_cfg.data['http-api']['host'],
|
||||
port=self.ap.system_cfg.data['http-api']['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder
|
||||
))
|
||||
|
||||
self.ap.task_mgr.create_task(
|
||||
self.quart_app.run_task(
|
||||
host=self.ap.system_cfg.data["http-api"]["host"],
|
||||
port=self.ap.system_cfg.data["http-api"]["port"],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
),
|
||||
name="http-api-quart",
|
||||
)
|
||||
|
||||
async def register_routes(self) -> None:
|
||||
|
||||
@self.quart_app.route('/healthz')
|
||||
|
||||
@self.quart_app.route("/healthz")
|
||||
async def healthz():
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "ok"
|
||||
}
|
||||
|
||||
return {"code": 0, "msg": "ok"}
|
||||
|
||||
for g in group.preregistered_groups:
|
||||
ginst = g(self.ap, self.quart_app)
|
||||
await ginst.initialize()
|
||||
|
||||
@@ -14,6 +14,7 @@ from ...core import app
|
||||
|
||||
class APIGroup(metaclass=abc.ABCMeta):
|
||||
"""API 组抽象类"""
|
||||
|
||||
_basic_info: dict = None
|
||||
_runtime_info: dict = None
|
||||
|
||||
@@ -32,33 +33,28 @@ class APIGroup(metaclass=abc.ABCMeta):
|
||||
data: dict = None,
|
||||
params: dict = None,
|
||||
headers: dict = {},
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
执行请求
|
||||
"""
|
||||
self._runtime_info['account_id'] = "-1"
|
||||
|
||||
self._runtime_info["account_id"] = "-1"
|
||||
|
||||
url = self.prefix + path
|
||||
data = json.dumps(data)
|
||||
headers['Content-Type'] = 'application/json'
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
params=params,
|
||||
headers=headers,
|
||||
**kwargs
|
||||
method, url, data=data, params=params, headers=headers, **kwargs
|
||||
) as resp:
|
||||
self.ap.logger.debug("data: %s", data)
|
||||
self.ap.logger.debug("ret: %s", await resp.text())
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'上报失败: {e}')
|
||||
|
||||
self.ap.logger.debug(f"上报失败: {e}")
|
||||
|
||||
async def do(
|
||||
self,
|
||||
method: str,
|
||||
@@ -66,32 +62,29 @@ class APIGroup(metaclass=abc.ABCMeta):
|
||||
data: dict = None,
|
||||
params: dict = None,
|
||||
headers: dict = {},
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> asyncio.Task:
|
||||
"""执行请求"""
|
||||
# task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
|
||||
|
||||
# self.ap.asyncio_tasks.append(task)
|
||||
|
||||
return self.ap.task_mgr.create_task(
|
||||
self._do(method, path, data, params, headers, **kwargs),
|
||||
kind="telemetry-operation",
|
||||
name=f"{method} {path}",
|
||||
).task
|
||||
|
||||
return self.ap.task_mgr.create_task(self._do(method, path, data, params, headers, **kwargs)).task
|
||||
|
||||
def gen_rid(
|
||||
self
|
||||
):
|
||||
def gen_rid(self):
|
||||
"""生成一个请求 ID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def basic_info(
|
||||
self
|
||||
):
|
||||
def basic_info(self):
|
||||
"""获取基本信息"""
|
||||
basic_info = APIGroup._basic_info.copy()
|
||||
basic_info['rid'] = self.gen_rid()
|
||||
basic_info["rid"] = self.gen_rid()
|
||||
return basic_info
|
||||
|
||||
def runtime_info(
|
||||
self
|
||||
):
|
||||
|
||||
def runtime_info(self):
|
||||
"""获取运行时信息"""
|
||||
return APIGroup._runtime_info
|
||||
|
||||
+4
-11
@@ -114,17 +114,10 @@ class Application:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# tasks = [
|
||||
# asyncio.create_task(self.platform_mgr.run()), # 消息平台
|
||||
# asyncio.create_task(self.ctrl.run()), # 消息处理循环
|
||||
# asyncio.create_task(self.http_ctrl.run()), # http 接口服务
|
||||
# asyncio.create_task(never_ending())
|
||||
# ]
|
||||
# self.asyncio_tasks.extend(tasks)
|
||||
self.task_mgr.create_task(self.platform_mgr.run())
|
||||
self.task_mgr.create_task(self.ctrl.run())
|
||||
self.task_mgr.create_task(self.http_ctrl.run())
|
||||
self.task_mgr.create_task(never_ending())
|
||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager")
|
||||
self.task_mgr.create_task(self.ctrl.run(), name="query-controller")
|
||||
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller")
|
||||
self.task_mgr.create_task(never_ending(), name="never-ending-task")
|
||||
|
||||
await self.task_mgr.wait_all()
|
||||
except asyncio.CancelledError:
|
||||
|
||||
+132
-6
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
from . import app
|
||||
|
||||
@@ -16,22 +17,68 @@ class TaskContext:
|
||||
"""记录日志"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_action = ""
|
||||
self.current_action = "default"
|
||||
self.log = ""
|
||||
|
||||
def log(self, msg: str):
|
||||
def _log(self, msg: str):
|
||||
self.log += msg + "\n"
|
||||
|
||||
def set_current_action(self, action: str):
|
||||
self.current_action = action
|
||||
|
||||
def trace(
|
||||
self,
|
||||
msg: str,
|
||||
action: str = None,
|
||||
):
|
||||
if action is not None:
|
||||
self.set_current_action(action)
|
||||
|
||||
self._log(
|
||||
f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | {self.current_action} | {msg}"
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"current_action": self.current_action, "log": self.log}
|
||||
|
||||
@staticmethod
|
||||
def new() -> TaskContext:
|
||||
return TaskContext()
|
||||
|
||||
@staticmethod
|
||||
def placeholder() -> TaskContext:
|
||||
global placeholder_context
|
||||
|
||||
if placeholder_context is None:
|
||||
placeholder_context = TaskContext()
|
||||
|
||||
return placeholder_context
|
||||
|
||||
|
||||
placeholder_context: TaskContext | None = None
|
||||
|
||||
|
||||
class TaskWrapper:
|
||||
"""任务包装器"""
|
||||
|
||||
_id_index: int = 0
|
||||
"""任务ID索引"""
|
||||
|
||||
id: int
|
||||
"""任务ID"""
|
||||
|
||||
task_type: str = "system" # 任务类型: system 或 user
|
||||
"""任务类型"""
|
||||
|
||||
kind: str = "system_task"
|
||||
"""任务种类"""
|
||||
|
||||
name: str = ""
|
||||
"""任务唯一名称"""
|
||||
|
||||
label: str = ""
|
||||
"""任务显示名称"""
|
||||
|
||||
task_context: TaskContext
|
||||
"""任务上下文"""
|
||||
|
||||
@@ -41,17 +88,61 @@ class TaskWrapper:
|
||||
ap: app.Application
|
||||
"""应用实例"""
|
||||
|
||||
def __init__(self, ap: app.Application, coro: typing.Coroutine, task_type: str = "system", context: TaskContext = None):
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
coro: typing.Coroutine,
|
||||
task_type: str = "system",
|
||||
kind: str = "system_task",
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
):
|
||||
self.id = TaskWrapper._id_index
|
||||
TaskWrapper._id_index += 1
|
||||
self.ap = ap
|
||||
self.task_context = context or TaskContext()
|
||||
self.task = self.ap.event_loop.create_task(coro)
|
||||
self.task_type = task_type
|
||||
self.kind = kind
|
||||
self.name = name
|
||||
self.label = label if label != "" else name
|
||||
self.task.set_name(name)
|
||||
|
||||
def assume_exception(self):
|
||||
try:
|
||||
return self.task.exception()
|
||||
except:
|
||||
return None
|
||||
|
||||
def assume_result(self):
|
||||
try:
|
||||
return self.task.result()
|
||||
except:
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"task_type": self.task_type,
|
||||
"kind": self.kind,
|
||||
"name": self.name,
|
||||
"label": self.label,
|
||||
"task_context": self.task_context.to_dict(),
|
||||
"runtime": {
|
||||
"done": self.task.done(),
|
||||
"state": self.task._state,
|
||||
"exception": self.assume_exception(),
|
||||
"result": self.assume_result(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
"""保存app中的所有异步任务
|
||||
包含系统级的和用户级(插件安装、更新等由用户直接发起的)的"""
|
||||
|
||||
|
||||
ap: app.Application
|
||||
|
||||
tasks: list[TaskWrapper]
|
||||
@@ -61,13 +152,48 @@ class AsyncTaskManager:
|
||||
self.ap = ap
|
||||
self.tasks = []
|
||||
|
||||
def create_task(self, coro: typing.Coroutine, task_type: str = "system", context: TaskContext = None) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, context)
|
||||
def create_task(
|
||||
self,
|
||||
coro: typing.Coroutine,
|
||||
task_type: str = "system",
|
||||
kind: str = "system-task",
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context)
|
||||
self.tasks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
def create_user_task(
|
||||
self,
|
||||
coro: typing.Coroutine,
|
||||
kind: str = "user-task",
|
||||
name: str = "",
|
||||
label: str = "",
|
||||
context: TaskContext = None,
|
||||
) -> TaskWrapper:
|
||||
return self.create_task(coro, "user", kind, name, label, context)
|
||||
|
||||
async def wait_all(self):
|
||||
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
|
||||
|
||||
def get_all_tasks(self) -> list[TaskWrapper]:
|
||||
return self.tasks
|
||||
|
||||
def get_tasks_dict(
|
||||
self,
|
||||
type: str = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"tasks": [
|
||||
t.to_dict() for t in self.tasks if type is None or t.task_type == type
|
||||
],
|
||||
"id_index": TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
def get_task_by_id(self, id: int) -> TaskWrapper | None:
|
||||
for t in self.tasks:
|
||||
if t.id == id:
|
||||
return t
|
||||
return None
|
||||
|
||||
@@ -62,7 +62,11 @@ class Controller:
|
||||
|
||||
# task = asyncio.create_task(_process_query(selected_query))
|
||||
# self.ap.asyncio_tasks.append(task)
|
||||
self.ap.task_mgr.create_task(_process_query(selected_query))
|
||||
self.ap.task_mgr.create_task(
|
||||
_process_query(selected_query),
|
||||
kind="query",
|
||||
name=f"query-{selected_query.query_id}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# traceback.print_exc()
|
||||
|
||||
@@ -186,7 +186,11 @@ class PlatformManager:
|
||||
for task in tasks:
|
||||
# async_task = asyncio.create_task(task)
|
||||
# self.ap.asyncio_tasks.append(async_task)
|
||||
self.ap.task_mgr.create_task(task)
|
||||
self.ap.task_mgr.create_task(
|
||||
task,
|
||||
kind="platform-adapter",
|
||||
name=f"platform-adapter-{adapter.name}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ..core import app
|
||||
from ..core import app, taskmgr
|
||||
|
||||
|
||||
class PluginInstaller(metaclass=abc.ABCMeta):
|
||||
@@ -40,6 +40,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
plugin_name: str,
|
||||
plugin_source: str=None,
|
||||
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
|
||||
):
|
||||
"""更新插件
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ import requests
|
||||
|
||||
from .. import installer, errors
|
||||
from ...utils import pkgmgr
|
||||
from ...core import taskmgr
|
||||
|
||||
|
||||
class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
@@ -94,13 +95,20 @@ class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
async def install_plugin(
|
||||
self,
|
||||
plugin_source: str,
|
||||
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
|
||||
):
|
||||
"""安装插件
|
||||
"""
|
||||
task_context.trace("下载插件源码...", "install-plugin")
|
||||
|
||||
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/")
|
||||
|
||||
task_context.trace("安装插件依赖...", "install-plugin")
|
||||
|
||||
await self.install_requirements("plugins/" + repo_label)
|
||||
|
||||
task_context.trace("完成.", "install-plugin")
|
||||
|
||||
await self.ap.plugin_mgr.setting.record_installed_plugin_source(
|
||||
"plugins/"+repo_label+'/', plugin_source
|
||||
)
|
||||
@@ -122,9 +130,12 @@ class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
self,
|
||||
plugin_name: str,
|
||||
plugin_source: str=None,
|
||||
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
|
||||
):
|
||||
"""更新插件
|
||||
"""
|
||||
task_context.trace("更新插件...", "update-plugin")
|
||||
|
||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is None:
|
||||
@@ -133,7 +144,9 @@ class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
if plugin_container.plugin_source:
|
||||
plugin_source = plugin_container.plugin_source
|
||||
|
||||
await self.install_plugin(plugin_source)
|
||||
task_context.trace("转交安装任务.", "update-plugin")
|
||||
|
||||
await self.install_plugin(plugin_source, task_context)
|
||||
|
||||
else:
|
||||
raise errors.PluginInstallerError('插件无源码信息,无法更新')
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from ..core import app
|
||||
from ..core import app, taskmgr
|
||||
from . import context, loader, events, installer, setting, models
|
||||
from .loaders import classic
|
||||
from .installers import github
|
||||
@@ -102,10 +102,11 @@ class PluginManager:
|
||||
self,
|
||||
plugin_name: str,
|
||||
plugin_source: str=None,
|
||||
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
|
||||
):
|
||||
"""更新插件
|
||||
"""
|
||||
await self.installer.update_plugin(plugin_name, plugin_source)
|
||||
await self.installer.update_plugin(plugin_name, plugin_source, task_context)
|
||||
|
||||
plugin_container = self.get_plugin_by_name(plugin_name)
|
||||
|
||||
@@ -120,6 +121,7 @@ class PluginManager:
|
||||
new_version="HEAD"
|
||||
)
|
||||
|
||||
|
||||
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
|
||||
"""通过插件名获取插件
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user