Files
LangBot/tests/unit_tests/pipeline/test_wrapper.py
T
彼方 48905ea080 feat(plugin): report deferred response delivery failures (#2287)
* feat(plugin): report deferred response delivery failures

* style: fix ruff format issues in plugin_diagnostics and test_handler_actions

---------

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
2026-06-26 23:45:10 +08:00

617 lines
22 KiB
Python

"""
Unit tests for ResponseWrapper (wrapper) pipeline stage.
Tests cover:
- MessageChain wrapping
- Command response wrapping
- Plugin response wrapping
- Assistant response wrapping with content/tool_calls
- Plugin event emission and INTERRUPT handling
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
from tests.factories import (
FakeApp,
text_query,
)
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_wrapper_module():
    """Return the ``langbot.pkg.pipeline.wrapper.wrapper`` module.

    The import happens at call time rather than at module load to avoid
    circular import issues.
    """
    registry_path = 'langbot.pkg.pipeline.pipelinemgr'
    wrapper_path = 'langbot.pkg.pipeline.wrapper.wrapper'
    # pipelinemgr has to be imported first: doing so triggers stage registration.
    import_module(registry_path)
    return import_module(wrapper_path)
def get_entities_module():
    """Return the ``langbot.pkg.pipeline.entities`` module, imported at call time."""
    entities_path = 'langbot.pkg.pipeline.entities'
    return import_module(entities_path)
def get_plugin_diagnostics_module():
    """Return the ``langbot.pkg.pipeline.plugin_diagnostics`` module, imported at call time.

    That module provides the plugin diagnostic attribution helpers used by
    the tests in this file.
    """
    diagnostics_path = 'langbot.pkg.pipeline.plugin_diagnostics'
    return import_module(diagnostics_path)
def make_wrapper_config():
    """Build a new pipeline config dict for the wrapper tests.

    All three ``output.misc`` switches start out as ``False``. A fresh
    nested structure is created on every call, so a test may mutate its
    copy without affecting other tests.
    """
    misc_switches = dict.fromkeys(
        ('at-sender', 'quote-origin', 'track-function-calls'),
        False,
    )
    return {'output': {'misc': misc_switches}}
def make_session():
    """Build a ``provider_session.Session`` for a person launcher.

    The launcher id and sender id are both 12345; the session has no
    conversations and no conversation currently in use.
    """
    person_id = 12345
    session_fields = {
        'launcher_type': provider_session.LauncherTypes.PERSON,
        'launcher_id': person_id,
        'sender_id': person_id,
        'use_prompt_name': 'default',
        'using_conversation': None,
        'conversations': [],
    }
    return provider_session.Session(**session_fields)
class TestResponseWrapperInit:
    """Initialization behaviour of the ResponseWrapper stage."""

    @pytest.mark.asyncio
    async def test_initialize_passes(self):
        """initialize() should accept an empty pipeline config without raising."""
        wrapper_stage = get_wrapper_module().ResponseWrapper(FakeApp())
        empty_config = {}
        await wrapper_stage.initialize(empty_config)
class TestResponseWrapperMessageChain:
    """Wrapping of responses that already are a MessageChain."""

    @pytest.mark.asyncio
    async def test_message_chain_direct_append(self):
        """A MessageChain in resp_messages ends up in resp_message_chain."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(FakeApp())
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
        query.resp_message_chain = []

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        outcome = outcomes[0]
        assert outcome.result_type == result_types.CONTINUE
        assert len(outcome.new_query.resp_message_chain) == 1

    @pytest.mark.asyncio
    async def test_message_chain_direct_append_consumes_pending_plugin_source(self):
        """A MessageChain reply recorded by an earlier plugin event keeps its attribution."""
        wrapper_mod = get_wrapper_module()
        wrapper_stage = wrapper_mod.ResponseWrapper(FakeApp())
        # The stage and the query each receive their own config dict.
        await wrapper_stage.initialize(make_wrapper_config())

        reply_chain = platform_message.MessageChain([platform_message.Plain(text='response')])
        query = text_query('hello')
        query.pipeline_config = make_wrapper_config()
        query.resp_messages = [reply_chain]
        query.resp_message_chain = []

        diagnostics = get_plugin_diagnostics_module()
        expected_plugin = {'author': 'tester', 'name': 'demo'}
        # Explicit response source: the plugin that produced the reply chain.
        response_sources = [
            {
                'kind': 'reply_message_chain',
                'plugin': dict(expected_plugin),
            }
        ]
        # Emitted-plugins entry naming a different plugin; it must not be the attributed one.
        emitted_plugins = [{'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}}}]
        diagnostics.record_pending_plugin_response_source(
            query,
            reply_chain,
            response_sources,
            emitted_plugins,
            'PersonNormalMessageReceived',
        )

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        sources = diagnostics._get_response_sources(outcomes[0].new_query, 0)
        primary = sources[0]
        assert primary.plugin == expected_plugin
        assert primary.event_name == 'PersonNormalMessageReceived'
        assert primary.is_approximate is False
        # Neither bookkeeping key may remain in the query variables.
        assert '_plugin_response_sources' not in query.variables
        assert '_plugin_pending_response_sources' not in query.variables
class TestResponseWrapperCommand:
    """Wrapping of responses whose role is 'command'."""

    @pytest.mark.asyncio
    async def test_command_response_prefix(self):
        """A command response is converted via get_content_platform_message_chain.

        The test checks that the stage continues and asks the response for its
        platform message chain exactly once; the prefix text itself is not
        inspected here.
        """
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(FakeApp())
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Stand-in for a command response message.
        help_chain = platform_message.MessageChain([platform_message.Plain(text='Help info')])
        command_resp = Mock(role='command')
        command_resp.get_content_platform_message_chain.return_value = help_chain
        query.resp_messages = [command_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == result_types.CONTINUE
        # The chain must have been requested from the response exactly once.
        command_resp.get_content_platform_message_chain.assert_called_once()
class TestResponseWrapperPlugin:
    """Wrapping of responses whose role is 'plugin'."""

    @pytest.mark.asyncio
    async def test_plugin_response_direct(self):
        """A plugin response yields a single CONTINUE result."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(FakeApp())
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Stand-in for a plugin response message.
        plugin_chain = platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
        plugin_resp = Mock(role='plugin')
        plugin_resp.get_content_platform_message_chain.return_value = plugin_chain
        query.resp_messages = [plugin_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == result_types.CONTINUE
class TestResponseWrapperAssistant:
    """Wrapping of responses whose role is 'assistant'."""

    @pytest.mark.asyncio
    async def test_assistant_content_response(self):
        """An assistant response with content emits a plugin event and continues."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        # Plugin event context: default not prevented, no custom reply chain.
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = False
        event_ctx.event.reply_message_chain = None
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Assistant response carrying text content only.
        content_chain = platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
        assistant_resp = Mock(role='assistant', content='Hello back!', tool_calls=None)
        assistant_resp.get_content_platform_message_chain.return_value = content_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == result_types.CONTINUE
        # The plugin event must have been emitted.
        app.plugin_connector.emit_event.assert_called()

    @pytest.mark.asyncio
    async def test_assistant_empty_content(self):
        """An assistant response without content or tool calls yields nothing and emits no event."""
        wrapper_mod = get_wrapper_module()

        app = FakeApp()
        app.plugin_connector.emit_event = AsyncMock()

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Assistant response with neither content nor tool calls.
        assistant_resp = Mock(role='assistant', content=None, tool_calls=None)
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert outcomes == []
        assert query.resp_message_chain == []
        app.plugin_connector.emit_event.assert_not_called()

    @pytest.mark.asyncio
    async def test_assistant_tool_calls(self):
        """With function-call tracking on, content plus tool_calls yields two CONTINUE results."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        # Plugin event context: default not prevented, no custom reply chain.
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = False
        event_ctx.event.reply_message_chain = None
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        config['output']['misc']['track-function-calls'] = True
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # `name` is set after construction because Mock(name=...) would name the mock itself.
        tool_call = Mock()
        tool_call.function.name = 'test_function'

        # Assistant response carrying both content and a tool call.
        content_chain = platform_message.MessageChain([platform_message.Plain(text='Processing...')])
        assistant_resp = Mock(role='assistant', content='Processing...', tool_calls=[tool_call])
        assistant_resp.get_content_platform_message_chain.return_value = content_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 2
        assert all(item.result_type == result_types.CONTINUE for item in outcomes)
        assert app.plugin_connector.emit_event.await_count == 2
class TestResponseWrapperInterrupt:
    """INTERRUPT behaviour when a plugin prevents the default action."""

    @pytest.mark.asyncio
    async def test_event_prevented_interrupts(self):
        """A prevented plugin event makes the stage yield a single INTERRUPT result."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        # Plugin event context reporting that the default was prevented.
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = True
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Assistant response carrying text content.
        content_chain = platform_message.MessageChain([platform_message.Plain(text='Hello!')])
        assistant_resp = Mock(role='assistant', content='Hello!', tool_calls=None)
        assistant_resp.get_content_platform_message_chain.return_value = content_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == result_types.INTERRUPT
class TestResponseWrapperCustomReply:
    """Custom reply chains supplied by a plugin event."""

    @pytest.mark.asyncio
    async def test_custom_reply_chain_used(self):
        """The event's reply_message_chain replaces the default reply."""
        wrapper_mod = get_wrapper_module()
        result_types = get_entities_module().ResultType

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        # Plugin event context carrying a custom reply chain.
        custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = False
        event_ctx.event.reply_message_chain = custom_chain
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        # Assistant response whose own text is the default reply.
        default_chain = platform_message.MessageChain([platform_message.Plain(text='Default reply')])
        assistant_resp = Mock(role='assistant', content='Default reply', tool_calls=None)
        assistant_resp.get_content_platform_message_chain.return_value = default_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        outcome = outcomes[0]
        assert outcome.result_type == result_types.CONTINUE
        # Exactly one chain is queued, and it is the custom one.
        queued_chains = outcome.new_query.resp_message_chain
        assert len(queued_chains) == 1
        assert 'Custom reply' in str(queued_chains[0])

    @pytest.mark.asyncio
    async def test_custom_reply_records_plugin_source(self):
        """A custom reply is attributed to the plugin named in _response_sources."""
        wrapper_mod = get_wrapper_module()

        app = FakeApp()
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = False
        event_ctx.event.reply_message_chain = custom_chain
        # Emitted-plugins entry for a different plugin; its config holds a secret
        # that must not appear in the recorded sources.
        event_ctx._emitted_plugins = [
            {
                'manifest': {'metadata': {'author': 'observer', 'name': 'not-reply-source'}},
                'plugin_config': {'token': 'secret-token'},
            },
        ]
        # Explicit response source naming the plugin that set the reply chain.
        event_ctx._response_sources = [
            {
                'kind': 'reply_message_chain',
                'plugin': {'author': 'tester', 'name': 'demo'},
            }
        ]
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        default_chain = platform_message.MessageChain([platform_message.Plain(text='Default reply')])
        assistant_resp = Mock(role='assistant', content='Default reply', tool_calls=None)
        assistant_resp.get_content_platform_message_chain.return_value = default_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        diagnostics = get_plugin_diagnostics_module()
        sources = diagnostics._get_response_sources(outcomes[0].new_query, 0)
        primary = sources[0]
        assert primary.plugin == {'author': 'tester', 'name': 'demo'}
        assert primary.event_name == 'NormalMessageResponded'
        assert primary.is_approximate is False
        assert 'secret-token' not in str(sources)
        assert '_plugin_response_sources' not in query.variables

    @pytest.mark.asyncio
    async def test_custom_reply_falls_back_to_emitted_plugins_for_old_runtime(self):
        """Without _response_sources, attribution falls back to _emitted_plugins and is approximate."""
        wrapper_mod = get_wrapper_module()

        app = FakeApp()
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
        event_ctx = Mock()
        event_ctx.is_prevented_default.return_value = False
        event_ctx.event.reply_message_chain = custom_chain
        # Only the emitted-plugins list is provided; _response_sources is deliberately left unset.
        event_ctx._emitted_plugins = [
            {'manifest': {'metadata': {'author': 'tester', 'name': 'demo'}}},
        ]
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        wrapper_stage = wrapper_mod.ResponseWrapper(app)
        await wrapper_stage.initialize(config)

        query = text_query('hello')
        query.pipeline_config = config
        query.resp_message_chain = []

        default_chain = platform_message.MessageChain([platform_message.Plain(text='Default reply')])
        assistant_resp = Mock(role='assistant', content='Default reply', tool_calls=None)
        assistant_resp.get_content_platform_message_chain.return_value = default_chain
        query.resp_messages = [assistant_resp]

        outcomes = [item async for item in wrapper_stage.process(query, 'ResponseWrapper')]

        diagnostics = get_plugin_diagnostics_module()
        sources = diagnostics._get_response_sources(outcomes[0].new_query, 0)
        primary = sources[0]
        assert primary.plugin == {'author': 'tester', 'name': 'demo'}
        assert primary.is_approximate is True
class TestResponseWrapperVariables:
    """Tests for the bound-plugins query variable."""

    @pytest.mark.asyncio
    async def test_bound_plugins_passed_to_event(self):
        """query.variables['_pipeline_bound_plugins'] is forwarded to emit_event.

        The list stored under ``_pipeline_bound_plugins`` must arrive as the
        second positional argument of ``plugin_connector.emit_event``.
        """
        wrapper = get_wrapper_module()
        app = FakeApp()

        # Mock session manager to return a valid Session
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())

        # Mock plugin connector: default not prevented, no custom reply chain
        mock_event_ctx = Mock()
        mock_event_ctx.is_prevented_default = Mock(return_value=False)
        mock_event_ctx.event = Mock()
        mock_event_ctx.event.reply_message_chain = None
        app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)

        stage = wrapper.ResponseWrapper(app)
        pipeline_config = make_wrapper_config()
        await stage.initialize(pipeline_config)

        bound_plugins = ['plugin1', 'plugin2']
        query = text_query('hello')
        query.pipeline_config = pipeline_config
        query.resp_message_chain = []
        query.variables['_pipeline_bound_plugins'] = list(bound_plugins)

        # Create assistant response
        assistant_resp = Mock()
        assistant_resp.role = 'assistant'
        assistant_resp.content = 'Hello'
        assistant_resp.tool_calls = None
        assistant_resp.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
        )
        query.resp_messages = [assistant_resp]

        # Drain the stage; only the recorded emit_event call matters here.
        async for _ in stage.process(query, 'ResponseWrapper'):
            pass

        # Fail with a clear assertion, not a TypeError on None, if no event was emitted.
        app.plugin_connector.emit_event.assert_called()
        emit_call = app.plugin_connector.emit_event.call_args
        # Second positional argument of emit_event is bound_plugins.
        assert emit_call.args[1] == bound_plugins