feat(plugin): report deferred response delivery failures (#2287)

* feat(plugin): report deferred response delivery failures

* style: fix ruff format issues in plugin_diagnostics and test_handler_actions

---------

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
This commit is contained in:
彼方
2026-06-26 23:45:10 +08:00
committed by GitHub
parent ddb77fc43c
commit 48905ea080
12 changed files with 785 additions and 29 deletions
@@ -0,0 +1,274 @@
from __future__ import annotations
import traceback
import weakref
from dataclasses import dataclass, field
from typing import Any
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
@dataclass(frozen=True)
class PluginResponseSource:
plugin: dict[str, str]
event_name: str | None = None
is_approximate: bool = False
@dataclass
class QueryDiagnosticState:
pending_by_chain_id: dict[int, list[PluginResponseSource]] = field(default_factory=dict)
by_response_index: dict[int, list[PluginResponseSource]] = field(default_factory=dict)
finalizer: weakref.finalize | None = None
_QUERY_STATES: dict[int, QueryDiagnosticState] = {}
def record_plugin_response_source(
query: pipeline_query.Query,
response_index: int,
response_sources: list[dict[str, Any]] | None,
emitted_plugins: list[dict[str, Any]] | None = None,
event_name: str | None = None,
) -> None:
plugin_sources = _build_plugin_sources(response_sources, emitted_plugins, event_name)
if not plugin_sources:
return
state = _get_or_create_query_state(query)
state.by_response_index[response_index] = plugin_sources
def record_last_plugin_response_source(
query: pipeline_query.Query,
response_sources: list[dict[str, Any]] | None,
emitted_plugins: list[dict[str, Any]] | None = None,
event_name: str | None = None,
) -> None:
record_plugin_response_source(
query,
len(query.resp_message_chain) - 1,
response_sources,
emitted_plugins,
event_name,
)
def record_pending_plugin_response_source(
query: pipeline_query.Query,
message_chain: platform_message.MessageChain,
response_sources: list[dict[str, Any]] | None,
emitted_plugins: list[dict[str, Any]] | None = None,
event_name: str | None = None,
) -> None:
plugin_sources = _build_plugin_sources(response_sources, emitted_plugins, event_name)
if not plugin_sources:
return
state = _get_or_create_query_state(query)
state.pending_by_chain_id[id(message_chain)] = plugin_sources
def consume_pending_plugin_response_source(
query: pipeline_query.Query,
message_chain: platform_message.MessageChain,
response_index: int,
) -> None:
state = _get_query_state(query)
if state is None:
return
source = state.pending_by_chain_id.pop(id(message_chain), None)
if source is None:
return
state.by_response_index[response_index] = source
def clear_response_source(query: pipeline_query.Query, response_index: int) -> None:
state = _get_query_state(query)
if state is None:
return
state.by_response_index.pop(response_index, None)
_discard_query_state_if_empty(query)
async def notify_response_delivery_failure(
ap: Any,
query: pipeline_query.Query,
response_index: int,
message_chain: platform_message.MessageChain,
error: Exception,
) -> None:
try:
plugin_refs = _get_response_sources(query, response_index)
if not plugin_refs:
return
connector = getattr(ap, 'plugin_connector', None)
if connector is None or not hasattr(connector, 'notify_plugin_diagnostic'):
return
for source in plugin_refs:
payload = _build_delivery_failure_payload(
plugin_ref=source.plugin,
event_name=source.event_name,
is_approximate=source.is_approximate,
query=query,
response_index=response_index,
message_chain=message_chain,
error=error,
)
try:
await connector.notify_plugin_diagnostic(payload)
except Exception as diag_error:
_debug(ap, f'Plugin diagnostic forwarding failed: {diag_error}')
except Exception as diag_error:
_debug(ap, f'Plugin diagnostic forwarding skipped: {diag_error}')
def get_emitted_plugins(event_ctx: Any) -> list[dict[str, Any]]:
emitted_plugins = getattr(event_ctx, '_emitted_plugins', [])
return emitted_plugins if isinstance(emitted_plugins, list) else []
def get_response_sources(event_ctx: Any) -> list[dict[str, Any]] | None:
event_attrs = vars(event_ctx)
if '_response_sources' not in event_attrs:
return None
response_sources = event_attrs['_response_sources']
return response_sources if isinstance(response_sources, list) else []
def _get_or_create_query_state(query: pipeline_query.Query) -> QueryDiagnosticState:
query_key = id(query)
state = _QUERY_STATES.get(query_key)
if state is not None:
return state
state = QueryDiagnosticState()
try:
state.finalizer = weakref.finalize(query, _discard_query_state, query_key)
except TypeError:
state.finalizer = None
_QUERY_STATES[query_key] = state
return state
def _get_query_state(query: pipeline_query.Query) -> QueryDiagnosticState | None:
return _QUERY_STATES.get(id(query))
def _discard_query_state(query_key: int) -> None:
_QUERY_STATES.pop(query_key, None)
def _discard_query_state_if_empty(query: pipeline_query.Query) -> None:
query_key = id(query)
state = _QUERY_STATES.get(query_key)
if state is None:
return
if state.pending_by_chain_id or state.by_response_index:
return
if state.finalizer is not None:
state.finalizer.detach()
_discard_query_state(query_key)
def _get_response_sources(
query: pipeline_query.Query,
response_index: int,
) -> list[PluginResponseSource]:
state = _get_query_state(query)
if state is None:
return []
return state.by_response_index.get(response_index, [])
def _extract_plugin_ref(plugin: Any) -> dict[str, str] | None:
manifest = plugin.get('manifest') if isinstance(plugin, dict) else None
metadata = manifest.get('metadata') if isinstance(manifest, dict) else None
if not isinstance(metadata, dict):
return None
author = metadata.get('author')
name = metadata.get('name')
if not author or not name:
return None
return {'author': str(author), 'name': str(name)}
def _extract_response_source_plugin_ref(source: Any) -> dict[str, str] | None:
if not isinstance(source, dict):
return None
if source.get('kind') != 'reply_message_chain':
return None
plugin_ref = source.get('plugin')
if not isinstance(plugin_ref, dict):
return None
author = plugin_ref.get('author')
name = plugin_ref.get('name')
if not author or not name:
return None
return {'author': str(author), 'name': str(name)}
def _build_plugin_sources(
response_sources: list[dict[str, Any]] | None,
emitted_plugins: list[dict[str, Any]] | None,
event_name: str | None,
) -> list[PluginResponseSource]:
if response_sources is not None:
plugin_refs = [_extract_response_source_plugin_ref(source) for source in response_sources]
return [
PluginResponseSource(plugin=plugin, event_name=event_name) for plugin in plugin_refs if plugin is not None
]
if emitted_plugins:
plugin_refs = [_extract_plugin_ref(plugin) for plugin in emitted_plugins]
return [
PluginResponseSource(plugin=plugin, event_name=event_name, is_approximate=True)
for plugin in plugin_refs
if plugin is not None
]
return []
def _debug(ap: Any, message: str) -> None:
logger = getattr(ap, 'logger', None)
if logger is not None:
logger.debug(message)
def _build_delivery_failure_payload(
plugin_ref: dict[str, str],
event_name: str | None,
is_approximate: bool,
query: pipeline_query.Query,
response_index: int,
message_chain: platform_message.MessageChain,
error: Exception,
) -> dict[str, Any]:
details: dict[str, Any] = {
'message_component_types': [component.__class__.__name__ for component in message_chain],
'message_preview': str(message_chain)[:200],
}
if is_approximate:
details['attribution_warning'] = (
'This diagnostic was delivered to all plugins that handled the event because the '
'plugin runtime did not report the exact reply_message_chain source.'
)
return {
'level': 'ERROR',
'code': 'response_delivery_failed',
'message': 'Failed to deliver a plugin-provided response message.',
'plugin': plugin_ref,
'query': {
'query_id': query.query_id,
'event_name': event_name or query.message_event.__class__.__name__,
'stage': query.current_stage_name or 'SendResponseBackStage',
'response_index': response_index,
},
'details': details,
'delivery': {
'error_type': error.__class__.__name__,
'error_message': str(error),
'traceback': traceback.format_exception_only(type(error), error)[-1].strip(),
},
}
@@ -9,6 +9,7 @@ from datetime import datetime
from .. import handler
from ... import entities
from ... import plugin_diagnostics
from ....provider import runner as runner_module
import langbot_plugin.api.entities.events as events
@@ -58,6 +59,13 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
plugin_diagnostics.record_pending_plugin_response_source(
query,
mc,
plugin_diagnostics.get_response_sources(event_ctx),
plugin_diagnostics.get_emitted_plugins(event_ctx),
event.event_name,
)
query.resp_messages.append(mc)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
@@ -4,6 +4,7 @@ import typing
from .. import handler
from ... import entities
from ... import plugin_diagnostics
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@@ -52,6 +53,13 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
plugin_diagnostics.record_pending_plugin_response_source(
query,
mc,
plugin_diagnostics.get_response_sources(event_ctx),
plugin_diagnostics.get_emitted_plugins(event_ctx),
event.event_name,
)
query.resp_messages.append(mc)
+30 -14
View File
@@ -9,6 +9,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from .. import stage, entities
from .. import plugin_diagnostics
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@@ -39,20 +40,35 @@ class SendResponseBackStage(stage.PipelineStage):
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][0]
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],
message=query.resp_message_chain[-1],
quote_origin=quote_origin,
is_final=is_final,
)
else:
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin,
response_index = len(query.resp_message_chain) - 1
message_chain = query.resp_message_chain[-1]
try:
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][0]
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],
message=message_chain,
quote_origin=quote_origin,
is_final=is_final,
)
else:
await query.adapter.reply_message(
message_source=query.message_event,
message=message_chain,
quote_origin=quote_origin,
)
except Exception as e:
await plugin_diagnostics.notify_response_delivery_failure(
self.ap,
query,
response_index,
message_chain,
e,
)
plugin_diagnostics.clear_response_source(query, response_index)
raise
plugin_diagnostics.clear_response_source(query, response_index)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
@@ -3,6 +3,7 @@ from __future__ import annotations
import typing
from .. import entities
from .. import plugin_diagnostics
from .. import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
@@ -78,6 +79,11 @@ class ResponseWrapper(stage.PipelineStage):
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
plugin_diagnostics.consume_pending_plugin_response_source(
query,
query.resp_messages[-1],
len(query.resp_message_chain) - 1,
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
@@ -129,8 +135,10 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply_message_chain is not None:
reply_chain = event_ctx.event.reply_message_chain
is_plugin_reply = True
else:
reply_chain = result.get_content_platform_message_chain()
is_plugin_reply = False
# Attach files the agent produced in the sandbox
# outbox, but only on the terminal assistant message.
@@ -138,6 +146,13 @@ class ResponseWrapper(stage.PipelineStage):
await self._append_outbound_attachments(query, reply_chain)
query.resp_message_chain.append(reply_chain)
if is_plugin_reply:
plugin_diagnostics.record_last_plugin_response_source(
query,
plugin_diagnostics.get_response_sources(event_ctx),
plugin_diagnostics.get_emitted_plugins(event_ctx),
event.event_name,
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -180,6 +195,12 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
plugin_diagnostics.record_last_plugin_response_source(
query,
plugin_diagnostics.get_response_sources(event_ctx),
plugin_diagnostics.get_emitted_plugins(event_ctx),
event.event_name,
)
else:
query.resp_message_chain.append(
+14
View File
@@ -737,6 +737,8 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
event_ctx = context.EventContext.from_event(event)
if not self.is_enable_plugin:
event_ctx._emitted_plugins = []
event_ctx._response_sources = []
return event_ctx
# Pass include_plugins to runtime for filtering
@@ -745,9 +747,21 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
)
event_ctx = context.EventContext.model_validate(event_ctx_result['event_context'])
event_ctx._emitted_plugins = event_ctx_result.get('emitted_plugins', [])
if 'response_sources' in event_ctx_result:
event_ctx._response_sources = event_ctx_result['response_sources']
return event_ctx
async def notify_plugin_diagnostic(self, diagnostic: dict[str, Any]) -> None:
"""Best-effort diagnostic forwarding to the plugin runtime."""
if not self.is_enable_plugin:
return
try:
await self.handler.notify_plugin_diagnostic(diagnostic)
except Exception as e:
self.ap.logger.debug(f'Plugin diagnostic forwarding skipped: {e}')
async def list_tools(self, bound_plugins: list[str] | None = None) -> list[ComponentManifest]:
if not self.is_enable_plugin:
return []
+21
View File
@@ -26,6 +26,15 @@ from ..core import app
from ..utils import constants
class _RawAction:
def __init__(self, value: str):
self.value = value
def _langbot_to_runtime_action(enum_name: str, fallback_value: str) -> Any:
return getattr(LangBotToRuntimeAction, enum_name, _RawAction(fallback_value))
def _make_rag_error_response(error: Exception, error_type: str, **extra_context) -> handler.ActionResponse:
"""Create a clean error response for RAG operations.
@@ -923,6 +932,18 @@ class RuntimeConnectionHandler(handler.Handler):
return result
async def notify_plugin_diagnostic(self, diagnostic: dict[str, Any]) -> dict[str, Any]:
"""Notify the plugin runtime about a best-effort plugin diagnostic.
This intentionally uses the raw protocol string instead of a SDK enum so
LangBot can keep running with older langbot-plugin versions.
"""
return await self.call_action(
_langbot_to_runtime_action('PLUGIN_DIAGNOSTIC', 'plugin_diagnostic'),
diagnostic,
timeout=5,
)
async def list_tools(self, include_plugins: list[str] | None = None) -> list[dict[str, Any]]:
"""List tools"""
result = await self.call_action(
@@ -662,6 +662,100 @@ class TestSendResponseBackStage:
assert len(outbound) == 1
assert outbound[0]['type'] == 'reply'
@pytest.mark.asyncio
async def test_send_response_failure_notifies_plugin_diagnostic(self, pipeline_app):
"""Plugin-provided deferred replies should report delivery failures."""
from langbot.pkg.pipeline import plugin_diagnostics
from langbot.pkg.pipeline.respback import respback
from tests.factories.message import text_chain
from langbot_plugin.api.entities.builtin.provider.message import Message
query = text_query('hello')
query.adapter.reply_message.side_effect = RuntimeError('send failed')
query.pipeline_config = create_minimal_pipeline_config()
query.current_stage_name = 'SendResponseBackStage'
query.resp_messages = [Message(role='assistant', content='test response')]
query.resp_message_chain = [text_chain('test response')]
plugin_diagnostics.record_plugin_response_source(
query,
0,
[
{
'kind': 'reply_message_chain',
'plugin': {'author': 'tester', 'name': 'demo'},
}
],
[{'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}}],
'NormalMessageResponded',
)
pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock()
respback_stage = respback.SendResponseBackStage(pipeline_app)
with pytest.raises(RuntimeError, match='send failed'):
await respback_stage.process(query, 'SendResponseBackStage')
pipeline_app.plugin_connector.notify_plugin_diagnostic.assert_awaited_once()
payload = pipeline_app.plugin_connector.notify_plugin_diagnostic.await_args.args[0]
assert payload['code'] == 'response_delivery_failed'
assert payload['plugin'] == {'author': 'tester', 'name': 'demo'}
assert payload['query']['event_name'] == 'NormalMessageResponded'
assert payload['delivery']['error_type'] == 'RuntimeError'
assert 'attribution_warning' not in payload['details']
@pytest.mark.asyncio
async def test_send_response_failure_warns_for_old_runtime_attribution(self, pipeline_app):
"""Older plugin runtimes without response_sources should get approximate diagnostics."""
from langbot.pkg.pipeline import plugin_diagnostics
from langbot.pkg.pipeline.respback import respback
from tests.factories.message import text_chain
from langbot_plugin.api.entities.builtin.provider.message import Message
query = text_query('hello')
query.adapter.reply_message.side_effect = RuntimeError('send failed')
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = [Message(role='assistant', content='test response')]
query.resp_message_chain = [text_chain('test response')]
plugin_diagnostics.record_plugin_response_source(
query,
0,
None,
[{'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}}],
'NormalMessageResponded',
)
pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock()
respback_stage = respback.SendResponseBackStage(pipeline_app)
with pytest.raises(RuntimeError, match='send failed'):
await respback_stage.process(query, 'SendResponseBackStage')
payload = pipeline_app.plugin_connector.notify_plugin_diagnostic.await_args.args[0]
assert payload['plugin'] == {'author': 'tester', 'name': 'demo'}
assert 'attribution_warning' in payload['details']
@pytest.mark.asyncio
async def test_send_response_failure_ignores_query_variable_spoofing(self, pipeline_app):
"""Plugin-controlled query variables must not mask delivery failures."""
from langbot.pkg.pipeline.respback import respback
from tests.factories.message import text_chain
from langbot_plugin.api.entities.builtin.provider.message import Message
query = text_query('hello')
query.adapter.reply_message.side_effect = RuntimeError('send failed')
query.pipeline_config = create_minimal_pipeline_config()
query.resp_messages = [Message(role='assistant', content='test response')]
query.resp_message_chain = [text_chain('test response')]
query.variables['_plugin_response_sources'] = {0: ['malformed']}
pipeline_app.plugin_connector.notify_plugin_diagnostic = AsyncMock()
respback_stage = respback.SendResponseBackStage(pipeline_app)
with pytest.raises(RuntimeError, match='send failed'):
await respback_stage.process(query, 'SendResponseBackStage')
pipeline_app.plugin_connector.notify_plugin_diagnostic.assert_not_called()
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestStageChainIntegration:
+142
View File
@@ -36,6 +36,11 @@ def get_entities_module():
return import_module('langbot.pkg.pipeline.entities')
def get_plugin_diagnostics_module():
"""Lazy import for plugin diagnostic attribution helpers."""
return import_module('langbot.pkg.pipeline.plugin_diagnostics')
def make_wrapper_config():
"""Create a pipeline config for wrapper tests."""
return {
@@ -106,6 +111,45 @@ class TestResponseWrapperMessageChain:
assert results[0].result_type == entities.ResultType.CONTINUE
assert len(results[0].new_query.resp_message_chain) == 1
@pytest.mark.asyncio
async def test_message_chain_direct_append_consumes_pending_plugin_source(self):
"""MessageChain replies from earlier plugin events keep attribution."""
wrapper = get_wrapper_module()
app = FakeApp()
stage = wrapper.ResponseWrapper(app)
await stage.initialize(make_wrapper_config())
reply_chain = platform_message.MessageChain([platform_message.Plain(text='response')])
query = text_query('hello')
query.pipeline_config = make_wrapper_config()
query.resp_messages = [reply_chain]
query.resp_message_chain = []
plugin_diagnostics = get_plugin_diagnostics_module()
plugin_diagnostics.record_pending_plugin_response_source(
query,
reply_chain,
[
{
'kind': 'reply_message_chain',
'plugin': {'author': 'tester', 'name': 'demo'},
}
],
[{'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}}],
'PersonNormalMessageReceived',
)
results = []
async for result in stage.process(query, 'ResponseWrapper'):
results.append(result)
sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0)
assert sources[0].plugin == {'author': 'tester', 'name': 'demo'}
assert sources[0].event_name == 'PersonNormalMessageReceived'
assert sources[0].is_approximate is False
assert '_plugin_response_sources' not in query.variables
assert '_plugin_pending_response_sources' not in query.variables
class TestResponseWrapperCommand:
"""Tests for command response wrapping."""
@@ -421,6 +465,104 @@ class TestResponseWrapperCustomReply:
chain = results[0].new_query.resp_message_chain[0]
assert 'Custom reply' in str(chain)
@pytest.mark.asyncio
async def test_custom_reply_records_plugin_source(self):
"""Plugin reply_message_chain should keep emitted plugin attribution."""
wrapper = get_wrapper_module()
app = FakeApp()
app.sess_mgr.get_session = AsyncMock(return_value=make_session())
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = custom_chain
mock_event_ctx._emitted_plugins = [
{
'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}},
'plugin_config': {'token': 'secret-token'},
},
]
mock_event_ctx._response_sources = [
{
'kind': 'reply_message_chain',
'plugin': {'author': 'tester', 'name': 'demo'},
}
]
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = wrapper.ResponseWrapper(app)
pipeline_config = make_wrapper_config()
await stage.initialize(pipeline_config)
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
results = []
async for result in stage.process(query, 'ResponseWrapper'):
results.append(result)
plugin_diagnostics = get_plugin_diagnostics_module()
sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0)
assert sources[0].plugin == {'author': 'tester', 'name': 'demo'}
assert sources[0].event_name == 'NormalMessageResponded'
assert sources[0].is_approximate is False
assert 'secret-token' not in str(sources)
assert '_plugin_response_sources' not in query.variables
@pytest.mark.asyncio
async def test_custom_reply_falls_back_to_emitted_plugins_for_old_runtime(self):
"""Older plugin runtimes without response_sources keep approximate attribution."""
wrapper = get_wrapper_module()
app = FakeApp()
app.sess_mgr.get_session = AsyncMock(return_value=make_session())
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
mock_event_ctx.event.reply_message_chain = custom_chain
mock_event_ctx._emitted_plugins = [
{'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}},
]
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = wrapper.ResponseWrapper(app)
pipeline_config = make_wrapper_config()
await stage.initialize(pipeline_config)
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
results = []
async for result in stage.process(query, 'ResponseWrapper'):
results.append(result)
plugin_diagnostics = get_plugin_diagnostics_module()
sources = plugin_diagnostics._get_response_sources(results[0].new_query, 0)
assert sources[0].plugin == {'author': 'tester', 'name': 'demo'}
assert sources[0].is_approximate is True
class TestResponseWrapperVariables:
"""Tests for bound plugins variable."""
@@ -13,6 +13,8 @@ import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
from tests.factories import text_query
def get_connector_module():
"""Lazy import to avoid circular import issues."""
@@ -132,6 +134,130 @@ class TestListPlugins:
assert result[0]['debug'] is True
class TestPluginDiagnostics:
@pytest.mark.asyncio
async def test_emit_event_preserves_response_sources(self):
connector = create_mock_connector()
query = text_query('hello')
event = query.message_event
object.__setattr__(event, 'query', query)
connector_module = get_connector_module()
original_from_event = connector_module.context.EventContext.from_event
original_model_validate = connector_module.context.EventContext.model_validate
response_sources = [
{
'kind': 'reply_message_chain',
'plugin': {'author': 'tester', 'name': 'demo'},
}
]
async def emit_event_response(event_context, include_plugins=None):
return {
'event_context': event_context,
'emitted_plugins': [],
'response_sources': response_sources,
}
connector.handler = AsyncMock()
connector.handler.emit_event = AsyncMock(side_effect=emit_event_response)
fake_event_ctx = Mock()
event_dump = event.model_dump()
event_dump['event_name'] = 'FriendMessage'
fake_event_ctx.model_dump.return_value = {
'query_id': query.query_id,
'eid': 0,
'event_name': 'FriendMessage',
'event': event_dump,
'is_prevent_default': False,
'is_prevent_postorder': False,
}
connector_module.context.EventContext.from_event = Mock(return_value=fake_event_ctx)
parsed_event_ctx = Mock()
connector_module.context.EventContext.model_validate = Mock(return_value=parsed_event_ctx)
try:
event_ctx = await connector.emit_event(event)
finally:
connector_module.context.EventContext.from_event = original_from_event
connector_module.context.EventContext.model_validate = original_model_validate
assert event_ctx is parsed_event_ctx
assert event_ctx._response_sources == response_sources
@pytest.mark.asyncio
async def test_emit_event_leaves_response_sources_absent_for_old_runtime(self):
connector = create_mock_connector()
query = text_query('hello')
event = query.message_event
object.__setattr__(event, 'query', query)
connector_module = get_connector_module()
original_from_event = connector_module.context.EventContext.from_event
original_model_validate = connector_module.context.EventContext.model_validate
async def emit_event_response(event_context, include_plugins=None):
return {
'event_context': event_context,
'emitted_plugins': [
{'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}},
],
}
connector.handler = AsyncMock()
connector.handler.emit_event = AsyncMock(side_effect=emit_event_response)
fake_event_ctx = Mock()
event_dump = event.model_dump()
event_dump['event_name'] = 'FriendMessage'
fake_event_ctx.model_dump.return_value = {
'query_id': query.query_id,
'eid': 0,
'event_name': 'FriendMessage',
'event': event_dump,
'is_prevent_default': False,
'is_prevent_postorder': False,
}
connector_module.context.EventContext.from_event = Mock(return_value=fake_event_ctx)
parsed_event_ctx = Mock()
connector_module.context.EventContext.model_validate = Mock(return_value=parsed_event_ctx)
try:
event_ctx = await connector.emit_event(event)
finally:
connector_module.context.EventContext.from_event = original_from_event
connector_module.context.EventContext.model_validate = original_model_validate
assert '_response_sources' not in vars(event_ctx)
assert event_ctx._emitted_plugins == [
{'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}},
]
@pytest.mark.asyncio
async def test_notify_plugin_diagnostic_skips_when_disabled(self):
connector_module = get_connector_module()
async def mock_disconnect(conn):
pass
mock_app = create_mock_app()
mock_app.instance_config.data = {'plugin': {'enable': False}}
connector = connector_module.PluginRuntimeConnector(mock_app, mock_disconnect)
connector.handler = AsyncMock()
await connector.notify_plugin_diagnostic({'code': 'response_delivery_failed'})
connector.handler.notify_plugin_diagnostic.assert_not_called()
@pytest.mark.asyncio
async def test_notify_plugin_diagnostic_is_best_effort(self):
connector = create_mock_connector()
connector.handler = AsyncMock()
connector.handler.notify_plugin_diagnostic = AsyncMock(side_effect=RuntimeError('action not found'))
await connector.notify_plugin_diagnostic({'code': 'response_delivery_failed'})
connector.handler.notify_plugin_diagnostic.assert_awaited_once()
connector.ap.logger.debug.assert_called_once()
class TestListKnowledgeEngines:
"""Tests for list_knowledge_engines method."""
+30
View File
@@ -159,6 +159,36 @@ class TestHandlerRagErrorResponse:
assert 'KeyError' in response.message
class TestHandlerPluginDiagnostic:
@pytest.mark.asyncio
async def test_notify_plugin_diagnostic_falls_back_to_raw_protocol_action(self):
"""Diagnostic forwarding works before the SDK enum exists."""
app = SimpleNamespace()
app.logger = SimpleNamespace(debug=MagicMock())
runtime_handler = make_handler(app)
runtime_handler.call_action = AsyncMock(return_value={})
payload = {'code': 'response_delivery_failed'}
await runtime_handler.notify_plugin_diagnostic(payload)
action = runtime_handler.call_action.await_args.args[0]
assert action.value == 'plugin_diagnostic'
assert runtime_handler.call_action.await_args.args[1] is payload
assert runtime_handler.call_action.await_args.kwargs['timeout'] == 5
def test_langbot_to_runtime_action_uses_enum_when_available(self):
"""The compatibility helper should prefer SDK enums once available."""
from langbot.pkg.plugin import handler as plugin_handler
sentinel = object()
original = plugin_handler.LangBotToRuntimeAction
plugin_handler.LangBotToRuntimeAction = SimpleNamespace(PLUGIN_DIAGNOSTIC=sentinel)
try:
assert plugin_handler._langbot_to_runtime_action('PLUGIN_DIAGNOSTIC', 'plugin_diagnostic') is sentinel
finally:
plugin_handler.LangBotToRuntimeAction = original
class TestConstantsSemanticVersion:
"""Tests for version constant access."""
+17 -15
View File
@@ -51,13 +51,15 @@ class TestRagRerankAction:
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=rerank_model)
runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
'rerank_model_uuid': 'rerank-1',
'query': 'hello',
'documents': ['a', 'b'],
'top_k': 1,
'extra_args': {'return_documents': False},
})
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value](
{
'rerank_model_uuid': 'rerank-1',
'query': 'hello',
'documents': ['a', 'b'],
'top_k': 1,
'extra_args': {'return_documents': False},
}
)
assert response.code == 0
assert response.data['results'] == [{'index': 1, 'relevance_score': 0.9}]
@@ -72,16 +74,16 @@ class TestRagRerankAction:
@pytest.mark.asyncio
async def test_returns_error_when_rerank_model_missing(self, app):
"""Missing rerank model returns an action error."""
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
side_effect=ValueError('not found')
)
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=ValueError('not found'))
runtime_handler = make_handler(app)
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
'rerank_model_uuid': 'missing',
'query': 'hello',
'documents': ['a'],
})
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value](
{
'rerank_model_uuid': 'missing',
'query': 'hello',
'documents': ['a'],
}
)
assert response.code != 0
assert 'Rerank model with rerank_model_uuid missing not found' in response.message