diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 71ec995b..a34eb082 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -57,6 +57,12 @@ class Query(pydantic.BaseModel): message_chain: platform_message.MessageChain """消息链,platform收到的原始消息链""" + pipeline_uuid: typing.Optional[str] = None + """流水线UUID。""" + + pipeline_config: typing.Optional[dict[str, typing.Any]] = None + """流水线配置,由 Pipeline 在运行开始时设置。""" + adapter: msadapter.MessagePlatformAdapter """消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器""" diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 807bec05..5d66f49e 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -50,14 +50,19 @@ class Controller: continue if selected_query: - async def _process_query(selected_query): + + async def _process_query(selected_query: entities.Query): async with self.semaphore: # 总并发上限 - await self.process_query(selected_query) + # find pipeline + pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid) + if pipeline: + await pipeline.run(selected_query) async with self.ap.query_pool: (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() + self.ap.task_mgr.create_task( _process_query(selected_query), kind="query", @@ -70,127 +75,6 @@ class Controller: self.ap.logger.error(f"控制器循环出错: {e}") self.ap.logger.error(f"Traceback: {traceback.format_exc()}") - async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): - """检查输出 - """ - if result.user_notice: - # 处理str类型 - - if isinstance(result.user_notice, str): - result.user_notice = platform_message.MessageChain( - platform_message.Plain(result.user_notice) - ) - elif isinstance(result.user_notice, list): - result.user_notice = platform_message.MessageChain( - *result.user_notice - ) - - await self.ap.platform_mgr.send( - query.message_event, - result.user_notice, - query.adapter - ) - if result.debug_notice: - self.ap.logger.debug(result.debug_notice) - if result.console_notice: - self.ap.logger.info(result.console_notice) - if result.error_notice: - self.ap.logger.error(result.error_notice) - - async def _execute_from_stage( - self, - stage_index: int, - query: entities.Query, - ): - """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 - - 如何看懂这里为什么这么写? - 去问 GPT-4: - Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], - 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, - 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 - Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: - - A B C D E F G - - 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: - - A B C D E F G - - 现在假设C返回的是AsyncGenerator,那么执行顺序是: - - A B C D E F G C D E F G C D E F G ... - Q3: 但是如果不止一个stage会返回生成器呢? - """ - i = stage_index - - while i < len(self.ap.stage_mgr.stage_containers): - stage_container = self.ap.stage_mgr.stage_containers[i] - - query.current_stage = stage_container # 标记到 Query 对象里 - - result = stage_container.inst.process(query, stage_container.inst_name) - - if isinstance(result, typing.Coroutine): - result = await result - - if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") - await self._check_output(query, result) - - if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif result.result_type == pipeline_entities.ResultType.CONTINUE: - query = result.new_query - elif isinstance(result, typing.AsyncGenerator): # 生成器 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") - - async for sub_result in result: - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") - await self._check_output(query, sub_result) - - if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: - query = sub_result.new_query - await self._execute_from_stage(i + 1, query) - break - - i += 1 - - async def process_query(self, query: entities.Query): - """处理请求 - """ - try: - - # ======== 触发 MessageReceived 事件 ======== - event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived - - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_type( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - message_chain=query.message_chain, - query=query - ) - ) - - if event_ctx.is_prevented_default(): - return - - self.ap.logger.debug(f"Processing query {query}") - - await self._execute_from_stage(0, query) - except Exception as e: - inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' - self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - finally: - self.ap.logger.debug(f"Query {query} processed") - async def run(self): """运行控制器 """ diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index a805e5cd..17189a52 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -1,13 +1,16 @@ from __future__ import annotations import typing +import traceback import sqlalchemy from ..core import app, entities +from . import entities as pipeline_entities from ..entity.persistence import pipeline as persistence_pipeline from . import stagemgr, stage - +from ..platform.types import message as platform_message, events as platform_events +from ..plugin import events class RuntimePipeline: """运行时流水线""" @@ -25,8 +28,130 @@ class RuntimePipeline: self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers - async def run(self): - pass + async def run(self, query: entities.Query): + query.pipeline_config = self.pipeline_entity.config + await self.process_query(query) + + async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): + """检查输出 + """ + if result.user_notice: + # 处理str类型 + + if isinstance(result.user_notice, str): + result.user_notice = platform_message.MessageChain( + platform_message.Plain(result.user_notice) + ) + elif isinstance(result.user_notice, list): + result.user_notice = platform_message.MessageChain( + *result.user_notice + ) + + await self.ap.platform_mgr.send( + query.message_event, + result.user_notice, + query.adapter + ) + if result.debug_notice: + self.ap.logger.debug(result.debug_notice) + if result.console_notice: + self.ap.logger.info(result.console_notice) + if result.error_notice: + self.ap.logger.error(result.error_notice) + + async def _execute_from_stage( + self, + stage_index: int, + query: entities.Query, + ): + """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 + + 如何看懂这里为什么这么写? + 去问 GPT-4: + Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], + 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, + 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 + Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: + + A B C D E F G + + 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: + + A B C D E F G + + 现在假设C返回的是AsyncGenerator,那么执行顺序是: + + A B C D E F G C D E F G C D E F G ... + Q3: 但是如果不止一个stage会返回生成器呢? + """ + i = stage_index + + while i < len(self.stage_containers): + stage_container = self.stage_containers[i] + + query.current_stage = stage_container # 标记到 Query 对象里 + + result = stage_container.inst.process(query, stage_container.inst_name) + + if isinstance(result, typing.Coroutine): + result = await result + + if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + await self._check_output(query, result) + + if result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif result.result_type == pipeline_entities.ResultType.CONTINUE: + query = result.new_query + elif isinstance(result, typing.AsyncGenerator): # 生成器 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + + async for sub_result in result: + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + await self._check_output(query, sub_result) + + if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + query = sub_result.new_query + await self._execute_from_stage(i + 1, query) + break + + i += 1 + + async def process_query(self, query: entities.Query): + """处理请求 + """ + try: + + # ======== 触发 MessageReceived 事件 ======== + event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_type( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, + query=query + ) + ) + + if event_ctx.is_prevented_default(): + return + + self.ap.logger.debug(f"Processing query {query}") + + await self._execute_from_stage(0, query) + except Exception as e: + inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' + self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + finally: + self.ap.logger.debug(f"Query {query} processed") class PipelineManager: diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index e358d249..d0c86e31 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -33,7 +33,8 @@ class QueryPool: sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, - adapter: msadapter.MessagePlatformAdapter + adapter: msadapter.MessagePlatformAdapter, + pipeline_uuid: str ) -> entities.Query: async with self.condition: query = entities.Query( @@ -43,6 +44,7 @@ class QueryPool: sender_id=sender_id, message_event=message_event, message_chain=message_chain, + pipeline_uuid=pipeline_uuid, resp_messages=[], resp_message_chain=[], adapter=adapter diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 210ee9ad..81f15655 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -50,6 +50,41 @@ class RuntimeBot: self.adapter = adapter self.task_context = taskmgr.TaskContext() + async def initialize(self): + + async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter, + pipeline_uuid=self.bot_entity.use_pipeline_uuid + ) + + async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): + + await self.ap.query_pool.add_query( + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter, + pipeline_uuid=self.bot_entity.use_pipeline_uuid + ) + + self.adapter.register_listener( + platform_events.FriendMessage, + on_friend_message + ) + self.adapter.register_listener( + platform_events.GroupMessage, + on_group_message + ) + async def run(self): async def exception_wrapper(): @@ -135,49 +170,20 @@ class PlatformManager: bot_entity = persistence_bot.Bot(**bot_entity._mapping) elif isinstance(bot_entity, dict): bot_entity = persistence_bot.Bot(**bot_entity) - - async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - - async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.GROUP, - launcher_id=event.group.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) adapter_inst = self.adapter_dict[bot_entity.adapter]( bot_entity.adapter_config, self.ap ) - adapter_inst.register_listener( - platform_events.FriendMessage, - on_friend_message - ) - adapter_inst.register_listener( - platform_events.GroupMessage, - on_group_message - ) - runtime_bot = RuntimeBot( ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst ) + await runtime_bot.initialize() + self.bots.append(runtime_bot) return runtime_bot diff --git a/templates/metadata/pipeline/ai.yaml b/templates/metadata/pipeline/ai.yaml index 38d579c0..74566a31 100644 --- a/templates/metadata/pipeline/ai.yaml +++ b/templates/metadata/pipeline/ai.yaml @@ -42,7 +42,7 @@ stages: zh_CN: 模型 type: select required: true - scope: llm-model + scope: /provider/models/llm - name: max-round label: en_US: Max Round @@ -54,11 +54,9 @@ stages: label: en_US: Prompt zh_CN: 提示词 - type: array + type: string required: true - default: [] - items: - type: string + default: "You are a helpful assistant." - name: dify-service-api label: en_US: Dify Service API diff --git a/templates/metadata/pipeline/safety.yaml b/templates/metadata/pipeline/safety.yaml index d19913af..09f8025b 100644 --- a/templates/metadata/pipeline/safety.yaml +++ b/templates/metadata/pipeline/safety.yaml @@ -46,7 +46,7 @@ stages: zh_CN: 窗口长度(秒) type: integer required: true - default: 10 + default: 60 - name: limitation label: en_US: Limitation