From ee3da8aa17e63448d526a1cb8aeb5f8c587f001b Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 2 Jul 2025 11:04:03 +0800 Subject: [PATCH] feat: adapt more events --- pkg/pipeline/cntfilter/cntfilter.py | 4 ++-- pkg/pipeline/controller.py | 6 +++--- pkg/pipeline/pipelinemgr.py | 2 +- pkg/pipeline/preproc/preproc.py | 11 +++++++++-- pkg/pipeline/process/handlers/chat.py | 11 +++++++++-- pkg/pipeline/wrapper/wrapper.py | 13 +++++++++++-- pkg/provider/entities.py | 2 +- pkg/provider/session/sessionmgr.py | 2 +- 8 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index c40a2042..26b00411 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -66,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage): if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) if not message.strip(): - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: @@ -85,7 +85,7 @@ class ContentFilterStage(stage.PipelineStage): elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 message = result.replacement - query.message_chain = platform_message.MessageChain(platform_message.Plain(text=message)) + query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)]) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 11bd8d46..b1dde4a6 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -35,9 +35,9 @@ class Controller: session = await self.ap.sess_mgr.get_session(query) self.ap.logger.debug(f'Checking query {query} session {session}') - if not session.semaphore.locked(): + if not session._semaphore.locked(): selected_query = query - await session.semaphore.acquire() + await session._semaphore.acquire() break @@ -62,7 +62,7 @@ class Controller: await pipeline.run(selected_query) async with self.ap.query_pool: - (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + (await self.ap.sess_mgr.get_session(selected_query))._semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index b32ad98d..6abb3972 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -90,7 +90,7 @@ class RuntimePipeline: # 处理str类型 if isinstance(result.user_notice, str): - result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice)) + result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)]) elif isinstance(result.user_notice, list): result.user_notice = platform_message.MessageChain(*result.user_notice) diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 894ceebf..344a136c 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -4,9 +4,10 @@ import datetime from .. import stage, entities from langbot_plugin.api.entities.builtin.provider import message as provider_message -from ...plugin import events +import langbot_plugin.api.entities.events as events import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.context as event_context @stage.stage_class('PreProcessor') @@ -108,7 +109,7 @@ class PreProcessor(stage.PipelineStage): query.user_message = provider_message.Message(role='user', content=content_list) # =========== 触发事件 PromptPreProcessing - event_ctx = await self.ap.plugin_mgr.emit_event( + event_ctx = event_context.EventContext( event=events.PromptPreProcessing( session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages, @@ -117,6 +118,12 @@ class PreProcessor(stage.PipelineStage): ) ) + event_ctx_result = await self.ap.plugin_connector.handler.emit_event( + event_ctx.model_dump(serialize_as_any=True) + ) + + event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context']) + query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 717727d0..24f4553d 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -7,13 +7,14 @@ import traceback from .. import handler from ... import entities from ....provider import runner as runner_module -from ....plugin import events import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.events as events from ....utils import importutil from ....provider import runners import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.context as event_context importutil.import_modules_in_pkg(runners) @@ -35,7 +36,7 @@ class ChatMessageHandler(handler.MessageHandler): else events.GroupNormalMessageReceived ) - event_ctx = await self.ap.plugin_mgr.emit_event( + event_ctx = event_context.EventContext( event=event_class( launcher_type=query.launcher_type.value, launcher_id=query.launcher_id, @@ -45,6 +46,12 @@ class ChatMessageHandler(handler.MessageHandler): ) ) + event_ctx_result = await self.ap.plugin_connector.handler.emit_event( + event_ctx.model_dump(serialize_as_any=True) + ) + + event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context']) + if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: mc = platform_message.MessageChain(event_ctx.event.reply) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 2c6e218e..3608d616 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -4,9 +4,11 @@ import typing from .. import entities from .. import stage -from ...plugin import events + import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.context as event_context +import langbot_plugin.api.entities.events as events @stage.stage_class('ResponseWrapper') @@ -57,7 +59,7 @@ class ResponseWrapper(stage.PipelineStage): reply_text = str(result.get_content_platform_message_chain()) # ============= 触发插件事件 =============== - event_ctx = await self.ap.plugin_mgr.emit_event( + event_ctx = event_context.EventContext( event=events.NormalMessageResponded( launcher_type=query.launcher_type.value, launcher_id=query.launcher_id, @@ -72,6 +74,13 @@ class ResponseWrapper(stage.PipelineStage): query=query, ) ) + + serialized_event_ctx = event_ctx.model_dump(serialize_as_any=True) + + event_ctx_result = await self.ap.plugin_connector.handler.emit_event(serialized_event_ctx) + + event_ctx = event_context.EventContext.parse_from_dict(event_ctx_result['event_context']) + if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 6de61e39..b03ece38 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -96,7 +96,7 @@ class Message(pydantic.BaseModel): if self.content is None: return None elif isinstance(self.content, str): - return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) + return platform_message.MessageChain([platform_message.Plain(text=(prefix_text + self.content))]) elif isinstance(self.content, list): mc = [] for ce in self.content: diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 03465e0b..11d0254c 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -33,8 +33,8 @@ class SessionManager: session = provider_session.Session( launcher_type=query.launcher_type, launcher_id=query.launcher_id, - semaphore=asyncio.Semaphore(session_concurrency), ) + session._semaphore = asyncio.Semaphore(session_concurrency) self.session_list.append(session) return session