From ff71c1cede774c5e2d11414785baf6fc8e843cee Mon Sep 17 00:00:00 2001 From: BiFangKNT <1320414964@qq.com> Date: Fri, 26 Jun 2026 10:40:00 +0800 Subject: [PATCH] feat(plugin): report deferred response delivery failures --- .../pkg/pipeline/plugin_diagnostics.py | 276 ++++++++++++++++++ .../pkg/pipeline/process/handlers/chat.py | 8 + .../pkg/pipeline/process/handlers/command.py | 8 + src/langbot/pkg/pipeline/respback/respback.py | 44 ++- src/langbot/pkg/pipeline/wrapper/wrapper.py | 21 ++ src/langbot/pkg/plugin/connector.py | 14 + src/langbot/pkg/plugin/handler.py | 21 ++ tests/integration/pipeline/test_full_flow.py | 94 ++++++ tests/unit_tests/pipeline/test_wrapper.py | 142 +++++++++ .../plugin/test_connector_methods.py | 126 ++++++++ tests/unit_tests/plugin/test_handler.py | 30 ++ 11 files changed, 770 insertions(+), 14 deletions(-) create mode 100644 src/langbot/pkg/pipeline/plugin_diagnostics.py diff --git a/src/langbot/pkg/pipeline/plugin_diagnostics.py b/src/langbot/pkg/pipeline/plugin_diagnostics.py new file mode 100644 index 000000000..951ae6e5c --- /dev/null +++ b/src/langbot/pkg/pipeline/plugin_diagnostics.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import traceback +import weakref +from dataclasses import dataclass, field +from typing import Any + +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.message as platform_message + + +@dataclass(frozen=True) +class PluginResponseSource: + plugin: dict[str, str] + event_name: str | None = None + is_approximate: bool = False + + +@dataclass +class QueryDiagnosticState: + pending_by_chain_id: dict[int, list[PluginResponseSource]] = field(default_factory=dict) + by_response_index: dict[int, list[PluginResponseSource]] = field(default_factory=dict) + finalizer: weakref.finalize | None = None + + +_QUERY_STATES: dict[int, QueryDiagnosticState] = {} + + +def record_plugin_response_source( + query: pipeline_query.Query, + response_index: int, + response_sources: list[dict[str, Any]] | None, + emitted_plugins: list[dict[str, Any]] | None = None, + event_name: str | None = None, +) -> None: + plugin_sources = _build_plugin_sources(response_sources, emitted_plugins, event_name) + if not plugin_sources: + return + state = _get_or_create_query_state(query) + state.by_response_index[response_index] = plugin_sources + + +def record_last_plugin_response_source( + query: pipeline_query.Query, + response_sources: list[dict[str, Any]] | None, + emitted_plugins: list[dict[str, Any]] | None = None, + event_name: str | None = None, +) -> None: + record_plugin_response_source( + query, + len(query.resp_message_chain) - 1, + response_sources, + emitted_plugins, + event_name, + ) + + +def record_pending_plugin_response_source( + query: pipeline_query.Query, + message_chain: platform_message.MessageChain, + response_sources: list[dict[str, Any]] | None, + emitted_plugins: list[dict[str, Any]] | None = None, + event_name: str | None = None, +) -> None: + plugin_sources = _build_plugin_sources(response_sources, emitted_plugins, event_name) + if not plugin_sources: + return + state = _get_or_create_query_state(query) + state.pending_by_chain_id[id(message_chain)] = plugin_sources + + +def consume_pending_plugin_response_source( + query: pipeline_query.Query, + message_chain: platform_message.MessageChain, + response_index: int, +) -> None: + state = _get_query_state(query) + if state is None: + return + source = state.pending_by_chain_id.pop(id(message_chain), None) + if source is None: + return + state.by_response_index[response_index] = source + + +def clear_response_source(query: pipeline_query.Query, response_index: int) -> None: + state = _get_query_state(query) + if state is None: + return + state.by_response_index.pop(response_index, None) + _discard_query_state_if_empty(query) + + +async def notify_response_delivery_failure( + ap: Any, + query: pipeline_query.Query, + response_index: int, + message_chain: platform_message.MessageChain, + error: Exception, +) -> None: + try: + plugin_refs = _get_response_sources(query, response_index) + if not plugin_refs: + return + connector = getattr(ap, 'plugin_connector', None) + if connector is None or not hasattr(connector, 'notify_plugin_diagnostic'): + return + for source in plugin_refs: + payload = _build_delivery_failure_payload( + plugin_ref=source.plugin, + event_name=source.event_name, + is_approximate=source.is_approximate, + query=query, + response_index=response_index, + message_chain=message_chain, + error=error, + ) + try: + await connector.notify_plugin_diagnostic(payload) + except Exception as diag_error: + _debug(ap, f'Plugin diagnostic forwarding failed: {diag_error}') + except Exception as diag_error: + _debug(ap, f'Plugin diagnostic forwarding skipped: {diag_error}') + + +def get_emitted_plugins(event_ctx: Any) -> list[dict[str, Any]]: + emitted_plugins = getattr(event_ctx, '_emitted_plugins', []) + return emitted_plugins if isinstance(emitted_plugins, list) else [] + + +def get_response_sources(event_ctx: Any) -> list[dict[str, Any]] | None: + event_attrs = vars(event_ctx) + if '_response_sources' not in event_attrs: + return None + response_sources = event_attrs['_response_sources'] + return response_sources if isinstance(response_sources, list) else [] + + +def _get_or_create_query_state(query: pipeline_query.Query) -> QueryDiagnosticState: + query_key = id(query) + state = _QUERY_STATES.get(query_key) + if state is not None: + return state + + state = QueryDiagnosticState() + try: + state.finalizer = weakref.finalize(query, _discard_query_state, query_key) + except TypeError: + state.finalizer = None + _QUERY_STATES[query_key] = state + return state + + +def _get_query_state(query: pipeline_query.Query) -> QueryDiagnosticState | None: + return _QUERY_STATES.get(id(query)) + + +def _discard_query_state(query_key: int) -> None: + _QUERY_STATES.pop(query_key, None) + + +def _discard_query_state_if_empty(query: pipeline_query.Query) -> None: + query_key = id(query) + state = _QUERY_STATES.get(query_key) + if state is None: + return + if state.pending_by_chain_id or state.by_response_index: + return + if state.finalizer is not None: + state.finalizer.detach() + _discard_query_state(query_key) + + +def _get_response_sources( + query: pipeline_query.Query, + response_index: int, +) -> list[PluginResponseSource]: + state = _get_query_state(query) + if state is None: + return [] + return state.by_response_index.get(response_index, []) + + +def _extract_plugin_ref(plugin: Any) -> dict[str, str] | None: + manifest = plugin.get('manifest') if isinstance(plugin, dict) else None + metadata = manifest.get('metadata') if isinstance(manifest, dict) else None + if not isinstance(metadata, dict): + return None + author = metadata.get('author') + name = metadata.get('name') + if not author or not name: + return None + return {'author': str(author), 'name': str(name)} + + +def _extract_response_source_plugin_ref(source: Any) -> dict[str, str] | None: + if not isinstance(source, dict): + return None + if source.get('kind') != 'reply_message_chain': + return None + plugin_ref = source.get('plugin') + if not isinstance(plugin_ref, dict): + return None + author = plugin_ref.get('author') + name = plugin_ref.get('name') + if not author or not name: + return None + return {'author': str(author), 'name': str(name)} + + +def _build_plugin_sources( + response_sources: list[dict[str, Any]] | None, + emitted_plugins: list[dict[str, Any]] | None, + event_name: str | None, +) -> list[PluginResponseSource]: + if response_sources is not None: + plugin_refs = [_extract_response_source_plugin_ref(source) for source in response_sources] + return [ + PluginResponseSource(plugin=plugin, event_name=event_name) + for plugin in plugin_refs + if plugin is not None + ] + + if emitted_plugins: + plugin_refs = [_extract_plugin_ref(plugin) for plugin in emitted_plugins] + return [ + PluginResponseSource(plugin=plugin, event_name=event_name, is_approximate=True) + for plugin in plugin_refs + if plugin is not None + ] + return [] + + +def _debug(ap: Any, message: str) -> None: + logger = getattr(ap, 'logger', None) + if logger is not None: + logger.debug(message) + + +def _build_delivery_failure_payload( + plugin_ref: dict[str, str], + event_name: str | None, + is_approximate: bool, + query: pipeline_query.Query, + response_index: int, + message_chain: platform_message.MessageChain, + error: Exception, +) -> dict[str, Any]: + details: dict[str, Any] = { + 'message_component_types': [component.__class__.__name__ for component in message_chain], + 'message_preview': str(message_chain)[:200], + } + if is_approximate: + details['attribution_warning'] = ( + 'This diagnostic was delivered to all plugins that handled the event because the ' + 'plugin runtime did not report the exact reply_message_chain source.' + ) + + return { + 'level': 'ERROR', + 'code': 'response_delivery_failed', + 'message': 'Failed to deliver a plugin-provided response message.', + 'plugin': plugin_ref, + 'query': { + 'query_id': query.query_id, + 'event_name': event_name or query.message_event.__class__.__name__, + 'stage': query.current_stage_name or 'SendResponseBackStage', + 'response_index': response_index, + }, + 'details': details, + 'delivery': { + 'error_type': error.__class__.__name__, + 'error_message': str(error), + 'traceback': traceback.format_exception_only(type(error), error)[-1].strip(), + }, + } diff --git a/src/langbot/pkg/pipeline/process/handlers/chat.py b/src/langbot/pkg/pipeline/process/handlers/chat.py index 488e19216..4e5c14ea4 100644 --- a/src/langbot/pkg/pipeline/process/handlers/chat.py +++ b/src/langbot/pkg/pipeline/process/handlers/chat.py @@ -9,6 +9,7 @@ from datetime import datetime from .. import handler from ... import entities +from ... import plugin_diagnostics from ....provider import runner as runner_module import langbot_plugin.api.entities.events as events @@ -58,6 +59,13 @@ class ChatMessageHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply_message_chain is not None: mc = event_ctx.event.reply_message_chain + plugin_diagnostics.record_pending_plugin_response_source( + query, + mc, + plugin_diagnostics.get_response_sources(event_ctx), + plugin_diagnostics.get_emitted_plugins(event_ctx), + event.event_name, + ) query.resp_messages.append(mc) yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/src/langbot/pkg/pipeline/process/handlers/command.py b/src/langbot/pkg/pipeline/process/handlers/command.py index 6d686acd4..09fa5379b 100644 --- a/src/langbot/pkg/pipeline/process/handlers/command.py +++ b/src/langbot/pkg/pipeline/process/handlers/command.py @@ -4,6 +4,7 @@ import typing from .. import handler from ... import entities +from ... import plugin_diagnostics import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @@ -52,6 +53,13 @@ class CommandHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply_message_chain is not None: mc = event_ctx.event.reply_message_chain + plugin_diagnostics.record_pending_plugin_response_source( + query, + mc, + plugin_diagnostics.get_response_sources(event_ctx), + plugin_diagnostics.get_emitted_plugins(event_ctx), + event.event_name, + ) query.resp_messages.append(mc) diff --git a/src/langbot/pkg/pipeline/respback/respback.py b/src/langbot/pkg/pipeline/respback/respback.py index 574404bcf..0c85fbb45 100644 --- a/src/langbot/pkg/pipeline/respback/respback.py +++ b/src/langbot/pkg/pipeline/respback/respback.py @@ -9,6 +9,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.provider.message as provider_message from .. import stage, entities +from .. import plugin_diagnostics import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @@ -39,20 +40,35 @@ class SendResponseBackStage(stage.PipelineStage): has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages) # TODO 命令与流式的兼容性问题 - if await query.adapter.is_stream_output_supported() and has_chunks: - is_final = [msg.is_final for msg in query.resp_messages][0] - await query.adapter.reply_message_chunk( - message_source=query.message_event, - bot_message=query.resp_messages[-1], - message=query.resp_message_chain[-1], - quote_origin=quote_origin, - is_final=is_final, - ) - else: - await query.adapter.reply_message( - message_source=query.message_event, - message=query.resp_message_chain[-1], - quote_origin=quote_origin, + response_index = len(query.resp_message_chain) - 1 + message_chain = query.resp_message_chain[-1] + + try: + if await query.adapter.is_stream_output_supported() and has_chunks: + is_final = [msg.is_final for msg in query.resp_messages][0] + await query.adapter.reply_message_chunk( + message_source=query.message_event, + bot_message=query.resp_messages[-1], + message=message_chain, + quote_origin=quote_origin, + is_final=is_final, + ) + else: + await query.adapter.reply_message( + message_source=query.message_event, + message=message_chain, + quote_origin=quote_origin, + ) + except Exception as e: + await plugin_diagnostics.notify_response_delivery_failure( + self.ap, + query, + response_index, + message_chain, + e, ) + plugin_diagnostics.clear_response_source(query, response_index) + raise + plugin_diagnostics.clear_response_source(query, response_index) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/src/langbot/pkg/pipeline/wrapper/wrapper.py b/src/langbot/pkg/pipeline/wrapper/wrapper.py index a158c1840..50db693d4 100644 --- a/src/langbot/pkg/pipeline/wrapper/wrapper.py +++ b/src/langbot/pkg/pipeline/wrapper/wrapper.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing from .. import entities +from .. import plugin_diagnostics from .. import stage import langbot_plugin.api.entities.builtin.platform.message as platform_message @@ -78,6 +79,11 @@ class ResponseWrapper(stage.PipelineStage): # 如果 resp_messages[-1] 已经是 MessageChain 了 if isinstance(query.resp_messages[-1], platform_message.MessageChain): query.resp_message_chain.append(query.resp_messages[-1]) + plugin_diagnostics.consume_pending_plugin_response_source( + query, + query.resp_messages[-1], + len(query.resp_message_chain) - 1, + ) yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) @@ -129,8 +135,10 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply_message_chain is not None: reply_chain = event_ctx.event.reply_message_chain + is_plugin_reply = True else: reply_chain = result.get_content_platform_message_chain() + is_plugin_reply = False # Attach files the agent produced in the sandbox # outbox, but only on the terminal assistant message. @@ -138,6 +146,13 @@ class ResponseWrapper(stage.PipelineStage): await self._append_outbound_attachments(query, reply_chain) query.resp_message_chain.append(reply_chain) + if is_plugin_reply: + plugin_diagnostics.record_last_plugin_response_source( + query, + plugin_diagnostics.get_response_sources(event_ctx), + plugin_diagnostics.get_emitted_plugins(event_ctx), + event.event_name, + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -180,6 +195,12 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply_message_chain is not None: query.resp_message_chain.append(event_ctx.event.reply_message_chain) + plugin_diagnostics.record_last_plugin_response_source( + query, + plugin_diagnostics.get_response_sources(event_ctx), + plugin_diagnostics.get_emitted_plugins(event_ctx), + event.event_name, + ) else: query.resp_message_chain.append( diff --git a/src/langbot/pkg/plugin/connector.py b/src/langbot/pkg/plugin/connector.py index 9cef9cb67..6075d4b68 100644 --- a/src/langbot/pkg/plugin/connector.py +++ b/src/langbot/pkg/plugin/connector.py @@ -737,6 +737,8 @@ class PluginRuntimeConnector(ManagedRuntimeConnector): event_ctx = context.EventContext.from_event(event) if not self.is_enable_plugin: + event_ctx._emitted_plugins = [] + event_ctx._response_sources = [] return event_ctx # Pass include_plugins to runtime for filtering @@ -745,9 +747,21 @@ class PluginRuntimeConnector(ManagedRuntimeConnector): ) event_ctx = context.EventContext.model_validate(event_ctx_result['event_context']) + event_ctx._emitted_plugins = event_ctx_result.get('emitted_plugins', []) + if 'response_sources' in event_ctx_result: + event_ctx._response_sources = event_ctx_result['response_sources'] return event_ctx + async def notify_plugin_diagnostic(self, diagnostic: dict[str, Any]) -> None: + """Best-effort diagnostic forwarding to the plugin runtime.""" + if not self.is_enable_plugin: + return + try: + await self.handler.notify_plugin_diagnostic(diagnostic) + except Exception as e: + self.ap.logger.debug(f'Plugin diagnostic forwarding skipped: {e}') + async def list_tools(self, bound_plugins: list[str] | None = None) -> list[ComponentManifest]: if not self.is_enable_plugin: return [] diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 2c4217910..dcfb006b5 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -26,6 +26,15 @@ from ..core import app from ..utils import constants +class _RawAction: + def __init__(self, value: str): + self.value = value + + +def _langbot_to_runtime_action(enum_name: str, fallback_value: str) -> Any: + return getattr(LangBotToRuntimeAction, enum_name, _RawAction(fallback_value)) + + def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse: """Create a clean error response for RAG operations. @@ -923,6 +932,18 @@ class RuntimeConnectionHandler(handler.Handler): return result + async def notify_plugin_diagnostic(self, diagnostic: dict[str, Any]) -> dict[str, Any]: + """Notify the plugin runtime about a best-effort plugin diagnostic. + + This intentionally uses the raw protocol string instead of a SDK enum so + LangBot can keep running with older langbot-plugin versions. + """ + return await self.call_action( + _langbot_to_runtime_action('PLUGIN_DIAGNOSTIC', 'plugin_diagnostic'), + diagnostic, + timeout=5, + ) + async def list_tools(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]: """List tools""" result = await self.call_action( diff --git a/tests/integration/pipeline/test_full_flow.py b/tests/integration/pipeline/test_full_flow.py index 6aa704436..767594c33 100644 --- a/tests/integration/pipeline/test_full_flow.py +++ b/tests/integration/pipeline/test_full_flow.py @@ -662,6 +662,100 @@ class TestSendResponseBackStage: assert len(outbound) == 1 assert outbound[0]['type'] == 'reply' + @pytest.mark.asyncio + async def test_send_response_failure_notifies_plugin_diagnostic(self, pipeline_app): + """Plugin-provided deferred replies should report delivery failures.""" + from langbot.pkg.pipeline import plugin_diagnostics + from langbot.pkg.pipeline.respback import respback + from tests.factories.message import text_chain + from langbot_plugin.api.entities.builtin.provider.message import Message + + query = text_query('hello') + query.adapter.reply_message.side_effect = RuntimeError('send failed') + query.pipeline_config = create_minimal_pipeline_config() + query.current_stage_name = 'SendResponseBackStage' + query.resp_messages = [Message(role='assistant', content='test response')] + query.resp_message_chain = [text_chain('test response')] + plugin_diagnostics.record_plugin_response_source( + query, + 0, + [ + { + 'kind': 'reply_message_chain', + 'plugin': {'author': 'tester', 'name': 'demo'}, + } + ], + [{'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}}], + 'NormalMessageResponded', + ) + pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock() + + respback_stage = respback.SendResponseBackStage(pipeline_app) + + with pytest.raises(RuntimeError, match='send failed'): + await respback_stage.process(query, 'SendResponseBackStage') + + pipeline_app.plugin_connector.notify_plugin_diagnostic.assert_awaited_once() + payload = pipeline_app.plugin_connector.notify_plugin_diagnostic.await_args.args[0] + assert payload['code'] == 'response_delivery_failed' + assert payload['plugin'] == {'author': 'tester', 'name': 'demo'} + assert payload['query']['event_name'] == 'NormalMessageResponded' + assert payload['delivery']['error_type'] == 'RuntimeError' + assert 'attribution_warning' not in payload['details'] + + @pytest.mark.asyncio + async def test_send_response_failure_warns_for_old_runtime_attribution(self, pipeline_app): + """Older plugin runtimes without response_sources should get approximate diagnostics.""" + from langbot.pkg.pipeline import plugin_diagnostics + from langbot.pkg.pipeline.respback import respback + from tests.factories.message import text_chain + from langbot_plugin.api.entities.builtin.provider.message import Message + + query = text_query('hello') + query.adapter.reply_message.side_effect = RuntimeError('send failed') + query.pipeline_config = create_minimal_pipeline_config() + query.resp_messages = [Message(role='assistant', content='test response')] + query.resp_message_chain = [text_chain('test response')] + plugin_diagnostics.record_plugin_response_source( + query, + 0, + None, + [{'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}}], + 'NormalMessageResponded', + ) + pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock() + + respback_stage = respback.SendResponseBackStage(pipeline_app) + + with pytest.raises(RuntimeError, match='send failed'): + await respback_stage.process(query, 'SendResponseBackStage') + + payload = pipeline_app.plugin_connector.notify_plugin_diagnostic.await_args.args[0] + assert payload['plugin'] == {'author': 'tester', 'name': 'demo'} + assert 'attribution_warning' in payload['details'] + + @pytest.mark.asyncio + async def test_send_response_failure_ignores_query_variable_spoofing(self, pipeline_app): + """Plugin-controlled query variables must not mask delivery failures.""" + from langbot.pkg.pipeline.respback import respback + from tests.factories.message import text_chain + from langbot_plugin.api.entities.builtin.provider.message import Message + + query = text_query('hello') + query.adapter.reply_message.side_effect = RuntimeError('send failed') + query.pipeline_config = create_minimal_pipeline_config() + query.resp_messages = [Message(role='assistant', content='test response')] + query.resp_message_chain = [text_chain('test response')] + query.variables['_plugin_response_sources'] = {0: ['malformed']} + pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock() + + respback_stage = respback.SendResponseBackStage(pipeline_app) + + with pytest.raises(RuntimeError, match='send failed'): + await respback_stage.process(query, 'SendResponseBackStage') + + pipeline_app.plugin_connector.notify_plugin_diagnostic.assert_not_called() + @pytest.mark.usefixtures('mock_circular_import_chain') class TestStageChainIntegration: diff --git a/tests/unit_tests/pipeline/test_wrapper.py b/tests/unit_tests/pipeline/test_wrapper.py index 8dea6c8bb..034c14115 100644 --- a/tests/unit_tests/pipeline/test_wrapper.py +++ b/tests/unit_tests/pipeline/test_wrapper.py @@ -36,6 +36,11 @@ def get_entities_module(): return import_module('langbot.pkg.pipeline.entities') +def get_plugin_diagnostics_module(): + """Lazy import for plugin diagnostic attribution helpers.""" + return import_module('langbot.pkg.pipeline.plugin_diagnostics') + + def make_wrapper_config(): """Create a pipeline config for wrapper tests.""" return { @@ -106,6 +111,45 @@ class TestResponseWrapperMessageChain: assert results[0].result_type == entities.ResultType.CONTINUE assert len(results[0].new_query.resp_message_chain) == 1 + @pytest.mark.asyncio + async def test_message_chain_direct_append_consumes_pending_plugin_source(self): + """MessageChain replies from earlier plugin events keep attribution.""" + wrapper = get_wrapper_module() + + app = FakeApp() + stage = wrapper.ResponseWrapper(app) + await stage.initialize(make_wrapper_config()) + + reply_chain = platform_message.MessageChain([platform_message.Plain(text='response')]) + query = text_query('hello') + query.pipeline_config = make_wrapper_config() + query.resp_messages = [reply_chain] + query.resp_message_chain = [] + plugin_diagnostics = get_plugin_diagnostics_module() + plugin_diagnostics.record_pending_plugin_response_source( + query, + reply_chain, + [ + { + 'kind': 'reply_message_chain', + 'plugin': {'author': 'tester', 'name': 'demo'}, + } + ], + [{'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}}], + 'PersonNormalMessageReceived', + ) + + results = [] + async for result in stage.process(query, 'ResponseWrapper'): + results.append(result) + + sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0) + assert sources[0].plugin == {'author': 'tester', 'name': 'demo'} + assert sources[0].event_name == 'PersonNormalMessageReceived' + assert sources[0].is_approximate is False + assert '_plugin_response_sources' not in query.variables + assert '_plugin_pending_response_sources' not in query.variables + class TestResponseWrapperCommand: """Tests for command response wrapping.""" @@ -421,6 +465,104 @@ class TestResponseWrapperCustomReply: chain = results[0].new_query.resp_message_chain[0] assert 'Custom reply' in str(chain) + @pytest.mark.asyncio + async def test_custom_reply_records_plugin_source(self): + """Plugin reply_message_chain should keep emitted plugin attribution.""" + wrapper = get_wrapper_module() + + app = FakeApp() + app.sess_mgr.get_session = AsyncMock(return_value=make_session()) + + custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')]) + mock_event_ctx = Mock() + mock_event_ctx.is_prevented_default = Mock(return_value=False) + mock_event_ctx.event = Mock() + mock_event_ctx.event.reply_message_chain = custom_chain + mock_event_ctx._emitted_plugins = [ + { + 'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}, + 'plugin_config': {'token': 'secret-token'}, + }, + ] + mock_event_ctx._response_sources = [ + { + 'kind': 'reply_message_chain', + 'plugin': {'author': 'tester', 'name': 'demo'}, + } + ] + app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + stage = wrapper.ResponseWrapper(app) + pipeline_config = make_wrapper_config() + await stage.initialize(pipeline_config) + + query = text_query('hello') + query.pipeline_config = pipeline_config + query.resp_message_chain = [] + assistant_resp = Mock() + assistant_resp.role = 'assistant' + assistant_resp.content = 'Default reply' + assistant_resp.tool_calls = None + assistant_resp.get_content_platform_message_chain = Mock( + return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')]) + ) + query.resp_messages = [assistant_resp] + + results = [] + async for result in stage.process(query, 'ResponseWrapper'): + results.append(result) + + plugin_diagnostics = get_plugin_diagnostics_module() + sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0) + assert sources[0].plugin == {'author': 'tester', 'name': 'demo'} + assert sources[0].event_name == 'NormalMessageResponded' + assert sources[0].is_approximate is False + assert 'secret-token' not in str(sources) + assert '_plugin_response_sources' not in query.variables + + @pytest.mark.asyncio + async def test_custom_reply_falls_back_to_emitted_plugins_for_old_runtime(self): + """Older plugin runtimes without response_sources keep approximate attribution.""" + wrapper = get_wrapper_module() + + app = FakeApp() + app.sess_mgr.get_session = AsyncMock(return_value=make_session()) + + custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')]) + mock_event_ctx = Mock() + mock_event_ctx.is_prevented_default = Mock(return_value=False) + mock_event_ctx.event = Mock() + mock_event_ctx.event.reply_message_chain = custom_chain + mock_event_ctx._emitted_plugins = [ + {'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}}, + ] + app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx) + + stage = wrapper.ResponseWrapper(app) + pipeline_config = make_wrapper_config() + await stage.initialize(pipeline_config) + + query = text_query('hello') + query.pipeline_config = pipeline_config + query.resp_message_chain = [] + assistant_resp = Mock() + assistant_resp.role = 'assistant' + assistant_resp.content = 'Default reply' + assistant_resp.tool_calls = None + assistant_resp.get_content_platform_message_chain = Mock( + return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')]) + ) + query.resp_messages = [assistant_resp] + + results = [] + async for result in stage.process(query, 'ResponseWrapper'): + results.append(result) + + plugin_diagnostics = get_plugin_diagnostics_module() + sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0) + assert sources[0].plugin == {'author': 'tester', 'name': 'demo'} + assert sources[0].is_approximate is True + class TestResponseWrapperVariables: """Tests for bound plugins variable.""" diff --git a/tests/unit_tests/plugin/test_connector_methods.py b/tests/unit_tests/plugin/test_connector_methods.py index 5f09ce5a4..34cab5271 100644 --- a/tests/unit_tests/plugin/test_connector_methods.py +++ b/tests/unit_tests/plugin/test_connector_methods.py @@ -13,6 +13,8 @@ import pytest from unittest.mock import Mock, AsyncMock from importlib import import_module +from tests.factories import text_query + def get_connector_module(): """Lazy import to avoid circular import issues.""" @@ -132,6 +134,130 @@ class TestListPlugins: assert result[0]['debug'] is True +class TestPluginDiagnostics: + @pytest.mark.asyncio + async def test_emit_event_preserves_response_sources(self): + connector = create_mock_connector() + query = text_query('hello') + event = query.message_event + object.__setattr__(event, 'query', query) + connector_module = get_connector_module() + original_from_event = connector_module.context.EventContext.from_event + original_model_validate = connector_module.context.EventContext.model_validate + response_sources = [ + { + 'kind': 'reply_message_chain', + 'plugin': {'author': 'tester', 'name': 'demo'}, + } + ] + + async def emit_event_response(event_context, include_plugins=None): + return { + 'event_context': event_context, + 'emitted_plugins': [], + 'response_sources': response_sources, + } + + connector.handler = AsyncMock() + connector.handler.emit_event = AsyncMock(side_effect=emit_event_response) + + fake_event_ctx = Mock() + event_dump = event.model_dump() + event_dump['event_name'] = 'FriendMessage' + fake_event_ctx.model_dump.return_value = { + 'query_id': query.query_id, + 'eid': 0, + 'event_name': 'FriendMessage', + 'event': event_dump, + 'is_prevent_default': False, + 'is_prevent_postorder': False, + } + connector_module.context.EventContext.from_event = Mock(return_value=fake_event_ctx) + parsed_event_ctx = Mock() + connector_module.context.EventContext.model_validate = Mock(return_value=parsed_event_ctx) + try: + event_ctx = await connector.emit_event(event) + finally: + connector_module.context.EventContext.from_event = original_from_event + connector_module.context.EventContext.model_validate = original_model_validate + + assert event_ctx is parsed_event_ctx + assert event_ctx._response_sources == response_sources + + @pytest.mark.asyncio + async def test_emit_event_leaves_response_sources_absent_for_old_runtime(self): + connector = create_mock_connector() + query = text_query('hello') + event = query.message_event + object.__setattr__(event, 'query', query) + connector_module = get_connector_module() + original_from_event = connector_module.context.EventContext.from_event + original_model_validate = connector_module.context.EventContext.model_validate + + async def emit_event_response(event_context, include_plugins=None): + return { + 'event_context': event_context, + 'emitted_plugins': [ + {'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}}, + ], + } + + connector.handler = AsyncMock() + connector.handler.emit_event = AsyncMock(side_effect=emit_event_response) + + fake_event_ctx = Mock() + event_dump = event.model_dump() + event_dump['event_name'] = 'FriendMessage' + fake_event_ctx.model_dump.return_value = { + 'query_id': query.query_id, + 'eid': 0, + 'event_name': 'FriendMessage', + 'event': event_dump, + 'is_prevent_default': False, + 'is_prevent_postorder': False, + } + connector_module.context.EventContext.from_event = Mock(return_value=fake_event_ctx) + parsed_event_ctx = Mock() + connector_module.context.EventContext.model_validate = Mock(return_value=parsed_event_ctx) + try: + event_ctx = await connector.emit_event(event) + finally: + connector_module.context.EventContext.from_event = original_from_event + connector_module.context.EventContext.model_validate = original_model_validate + + assert '_response_sources' not in vars(event_ctx) + assert event_ctx._emitted_plugins == [ + {'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}}, + ] + + @pytest.mark.asyncio + async def test_notify_plugin_diagnostic_skips_when_disabled(self): + connector_module = get_connector_module() + + async def mock_disconnect(conn): + pass + + mock_app = create_mock_app() + mock_app.instance_config.data = {'plugin': {'enable': False}} + connector = connector_module.PluginRuntimeConnector(mock_app, mock_disconnect) + connector.handler = AsyncMock() + + await connector.notify_plugin_diagnostic({'code': 'response_delivery_failed'}) + + connector.handler.notify_plugin_diagnostic.assert_not_called() + + @pytest.mark.asyncio + async def test_notify_plugin_diagnostic_is_best_effort(self): + connector = create_mock_connector() + connector.handler = AsyncMock() + connector.handler.notify_plugin_diagnostic = AsyncMock(side_effect=RuntimeError('action not found')) + + await connector.notify_plugin_diagnostic({'code': 'response_delivery_failed'}) + + connector.handler.notify_plugin_diagnostic.assert_awaited_once() + connector.ap.logger.debug.assert_called_once() + + class TestListKnowledgeEngines: """Tests for list_knowledge_engines method.""" diff --git a/tests/unit_tests/plugin/test_handler.py b/tests/unit_tests/plugin/test_handler.py index 989a333a4..a2fdddd33 100644 --- a/tests/unit_tests/plugin/test_handler.py +++ b/tests/unit_tests/plugin/test_handler.py @@ -159,6 +159,36 @@ class TestHandlerRagErrorResponse: assert 'KeyError' in response.message +class TestHandlerPluginDiagnostic: + @pytest.mark.asyncio + async def test_notify_plugin_diagnostic_falls_back_to_raw_protocol_action(self): + """Diagnostic forwarding works before the SDK enum exists.""" + app = SimpleNamespace() + app.logger = SimpleNamespace(debug=MagicMock()) + runtime_handler = make_handler(app) + runtime_handler.call_action = AsyncMock(return_value={}) + + payload = {'code': 'response_delivery_failed'} + await runtime_handler.notify_plugin_diagnostic(payload) + + action = runtime_handler.call_action.await_args.args[0] + assert action.value == 'plugin_diagnostic' + assert runtime_handler.call_action.await_args.args[1] is payload + assert runtime_handler.call_action.await_args.kwargs['timeout'] == 5 + + def test_langbot_to_runtime_action_uses_enum_when_available(self): + """The compatibility helper should prefer SDK enums once available.""" + from langbot.pkg.plugin import handler as plugin_handler + + sentinel = object() + original = plugin_handler.LangBotToRuntimeAction + plugin_handler.LangBotToRuntimeAction = SimpleNamespace(PLUGIN_DIAGNOSTIC=sentinel) + try: + assert plugin_handler._langbot_to_runtime_action('PLUGIN_DIAGNOSTIC', 'plugin_diagnostic') is sentinel + finally: + plugin_handler.LangBotToRuntimeAction = original + + class TestConstantsSemanticVersion: """Tests for version constant access."""