feat: pipeline invoking

This commit is contained in:
Junyan Qin
2025-03-28 23:42:41 +08:00
parent f8127eb585
commit 5ff59f1b07
7 changed files with 185 additions and 164 deletions

View File

@@ -50,14 +50,19 @@ class Controller:
continue
if selected_query:
async def _process_query(selected_query):
async def _process_query(selected_query: entities.Query):
async with self.semaphore: # 总并发上限
await self.process_query(selected_query)
# find pipeline
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind="query",
@@ -70,127 +75,6 @@ class Controller:
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(
*result.user_notice
)
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
if result.error_notice:
self.ap.logger.error(result.error_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
try:
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""