# Files
# LangBot/tests/unit_tests/pipeline/test_wrapper.py
# 2026-05-16 10:30:17 +08:00
#
# 477 lines
# 16 KiB
# Python

"""
Unit tests for ResponseWrapper (wrapper) pipeline stage.
Tests cover:
- MessageChain wrapping
- Command response wrapping
- Plugin response wrapping
- Assistant response wrapping with content/tool_calls
- Plugin event emission and INTERRUPT handling
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, AsyncMock
from importlib import import_module
from tests.factories import (
FakeApp,
text_query,
)
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_wrapper_module():
    """Import the ResponseWrapper stage module on demand and return it.

    The import is deferred to call time to avoid circular import issues.
    """
    # pipelinemgr must be loaded first: importing it triggers stage registration.
    import_module('langbot.pkg.pipeline.pipelinemgr')
    wrapper_module = import_module('langbot.pkg.pipeline.wrapper.wrapper')
    return wrapper_module
def get_entities_module():
    """Import the pipeline entities module on demand and return it."""
    entities_module = import_module('langbot.pkg.pipeline.entities')
    return entities_module
def make_wrapper_config():
    """Return a fresh pipeline config dict for the wrapper tests.

    Every flag under ``output.misc`` starts out disabled; a new dict is built
    on each call, so a test may flip a flag without affecting other tests.
    """
    misc_flag_names = ('at-sender', 'quote-origin', 'track-function-calls')
    misc_flags = dict.fromkeys(misc_flag_names, False)
    return {'output': {'misc': misc_flags}}
def make_session():
    """Build a valid Session (PERSON launcher, no conversations) for tests."""
    session_fields = dict(
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=12345,
        sender_id=12345,
        use_prompt_name="default",
        using_conversation=None,
        conversations=[],
    )
    return provider_session.Session(**session_fields)
class TestResponseWrapperInit:
    """Initialization behaviour of the ResponseWrapper stage."""

    @pytest.mark.asyncio
    async def test_initialize_passes(self):
        """initialize() with an empty pipeline config must not raise."""
        stage = get_wrapper_module().ResponseWrapper(FakeApp())
        # No assertion needed: the test passes as long as no exception escapes.
        await stage.initialize({})
class TestResponseWrapperMessageChain:
    """Covers responses that are already MessageChain objects."""

    @pytest.mark.asyncio
    async def test_message_chain_direct_append(self):
        """A MessageChain in resp_messages yields one CONTINUE result and one chain entry."""
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()
        config = make_wrapper_config()

        stage = wrapper_mod.ResponseWrapper(FakeApp())
        await stage.initialize(config)

        ready_chain = platform_message.MessageChain([platform_message.Plain(text="response")])
        query = text_query("hello")
        query.pipeline_config = config
        query.resp_messages = [ready_chain]
        query.resp_message_chain = []

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        only = outcomes[0]
        assert only.result_type == entities_mod.ResultType.CONTINUE
        assert len(only.new_query.resp_message_chain) == 1
class TestResponseWrapperCommand:
    """Covers wrapping of responses whose role is 'command'."""

    @pytest.mark.asyncio
    async def test_command_response_prefix(self):
        """A command response is wrapped through get_content_platform_message_chain.

        The stage is expected to add the [bot] prefix via that accessor; this
        test only verifies that the accessor is invoked exactly once and that a
        single CONTINUE result is produced (the prefix text is not inspected).
        """
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()
        config = make_wrapper_config()

        stage = wrapper_mod.ResponseWrapper(FakeApp())
        await stage.initialize(config)

        # Stand-in for a command reply message.
        help_chain = platform_message.MessageChain([platform_message.Plain(text="Help info")])
        command_reply = Mock(role='command')
        command_reply.get_content_platform_message_chain = Mock(return_value=help_chain)

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [command_reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == entities_mod.ResultType.CONTINUE
        # The prefix is applied inside the accessor, so check that it was used.
        command_reply.get_content_platform_message_chain.assert_called_once()
class TestResponseWrapperPlugin:
    """Tests for plugin response wrapping."""

    @pytest.mark.asyncio
    async def test_plugin_response_direct(self):
        """Plugin response should be wrapped without prefix.

        Verifies that a single CONTINUE result is produced, that the response's
        message chain accessor is consulted once, and that exactly one chain is
        placed in resp_message_chain.
        """
        wrapper = get_wrapper_module()
        entities = get_entities_module()

        app = FakeApp()
        stage = wrapper.ResponseWrapper(app)
        pipeline_config = make_wrapper_config()
        await stage.initialize(pipeline_config)

        query = text_query("hello")
        query.pipeline_config = pipeline_config
        query.resp_message_chain = []

        # Create a plugin response message
        plugin_resp = Mock()
        plugin_resp.role = 'plugin'
        plugin_resp.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
        )
        query.resp_messages = [plugin_resp]

        results = []
        async for result in stage.process(query, 'ResponseWrapper'):
            results.append(result)

        assert len(results) == 1
        assert results[0].result_type == entities.ResultType.CONTINUE
        # Checking only the result type would pass even if nothing was wrapped,
        # so also verify the wrapping itself (mirrors the command and
        # MessageChain tests in this module).
        plugin_resp.get_content_platform_message_chain.assert_called_once()
        assert len(results[0].new_query.resp_message_chain) == 1
class TestResponseWrapperAssistant:
    """Covers wrapping of responses whose role is 'assistant'."""

    @pytest.mark.asyncio
    async def test_assistant_content_response(self):
        """Assistant content produces one CONTINUE result and emits a plugin event."""
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())
        # Plugin event context: default not prevented, no replacement reply.
        event_ctx = Mock(event=Mock(reply_message_chain=None))
        event_ctx.is_prevented_default = Mock(return_value=False)
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        stage = wrapper_mod.ResponseWrapper(app)
        await stage.initialize(config)

        # Assistant reply carrying text content and no tool calls.
        reply = Mock(role='assistant', content="Hello back!", tool_calls=None)
        reply.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
        )

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == entities_mod.ResultType.CONTINUE
        # The plugin event must have fired.
        app.plugin_connector.emit_event.assert_called()

    @pytest.mark.asyncio
    async def test_assistant_empty_content(self):
        """No content and no tool calls: nothing is yielded and no event is emitted."""
        wrapper_mod = get_wrapper_module()

        app = FakeApp()
        app.plugin_connector.emit_event = AsyncMock()

        config = make_wrapper_config()
        stage = wrapper_mod.ResponseWrapper(app)
        await stage.initialize(config)

        # Assistant reply with neither content nor tool calls.
        reply = Mock(role='assistant', content=None, tool_calls=None)

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert outcomes == []
        assert query.resp_message_chain == []
        app.plugin_connector.emit_event.assert_not_called()

    @pytest.mark.asyncio
    async def test_assistant_tool_calls(self):
        """Content plus tool calls, with tracking on, gives two CONTINUE results and two events."""
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())
        # Plugin event context: default not prevented, no replacement reply.
        event_ctx = Mock(event=Mock(reply_message_chain=None))
        event_ctx.is_prevented_default = Mock(return_value=False)
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        # Same base config as the other tests, but with function-call tracking on.
        config = make_wrapper_config()
        config['output']['misc']['track-function-calls'] = True
        stage = wrapper_mod.ResponseWrapper(app)
        await stage.initialize(config)

        # `name` is special in the Mock constructor, so assign it afterwards.
        called_function = Mock()
        called_function.name = 'test_function'
        tool_call = Mock(function=called_function)

        reply = Mock(role='assistant', content="Processing...", tool_calls=[tool_call])
        reply.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
        )

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 2
        assert all(outcome.result_type == entities_mod.ResultType.CONTINUE for outcome in outcomes)
        assert app.plugin_connector.emit_event.await_count == 2
class TestResponseWrapperInterrupt:
    """Covers the INTERRUPT outcome when a plugin prevents the default behaviour."""

    @pytest.mark.asyncio
    async def test_event_prevented_interrupts(self):
        """A prevented plugin event turns the single result into INTERRUPT."""
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())
        # Plugin event context reporting that the default was prevented.
        event_ctx = Mock()
        event_ctx.is_prevented_default = Mock(return_value=True)
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        stage = wrapper_mod.ResponseWrapper(app)
        await stage.initialize(config)

        # Assistant reply carrying text content and no tool calls.
        reply = Mock(role='assistant', content="Hello!", tool_calls=None)
        reply.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
        )

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        assert outcomes[0].result_type == entities_mod.ResultType.INTERRUPT
class TestResponseWrapperCustomReply:
    """Covers a plugin supplying its own reply through the event."""

    @pytest.mark.asyncio
    async def test_custom_reply_chain_used(self):
        """The event's reply_message_chain ends up in resp_message_chain."""
        wrapper_mod = get_wrapper_module()
        entities_mod = get_entities_module()

        app = FakeApp()
        # The session manager hands back a valid Session.
        app.sess_mgr.get_session = AsyncMock(return_value=make_session())
        # Plugin event context that keeps the default flow but carries a custom reply.
        plugin_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
        event_ctx = Mock(event=Mock(reply_message_chain=plugin_chain))
        event_ctx.is_prevented_default = Mock(return_value=False)
        app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)

        config = make_wrapper_config()
        stage = wrapper_mod.ResponseWrapper(app)
        await stage.initialize(config)

        # Assistant reply whose own text differs from the plugin's reply.
        reply = Mock(role='assistant', content="Default reply", tool_calls=None)
        reply.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
        )

        query = text_query("hello")
        query.pipeline_config = config
        query.resp_message_chain = []
        query.resp_messages = [reply]

        outcomes = [outcome async for outcome in stage.process(query, 'ResponseWrapper')]

        assert len(outcomes) == 1
        only = outcomes[0]
        assert only.result_type == entities_mod.ResultType.CONTINUE
        # Exactly one chain is present, and it is the plugin-provided one.
        wrapped_chains = only.new_query.resp_message_chain
        assert len(wrapped_chains) == 1
        assert "Custom reply" in str(wrapped_chains[0])
class TestResponseWrapperVariables:
    """Tests for bound plugins variable."""

    @pytest.mark.asyncio
    async def test_bound_plugins_passed_to_event(self):
        """_pipeline_bound_plugins should be passed to emit_event.

        The list stored under query.variables['_pipeline_bound_plugins'] must
        show up as the second positional argument of emit_event.
        """
        wrapper = get_wrapper_module()

        app = FakeApp()
        # Mock session manager to return a valid Session
        session = make_session()
        app.sess_mgr.get_session = AsyncMock(return_value=session)

        # Mock plugin connector
        mock_event_ctx = Mock()
        mock_event_ctx.is_prevented_default = Mock(return_value=False)
        mock_event_ctx.event = Mock()
        mock_event_ctx.event.reply_message_chain = None
        app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)

        stage = wrapper.ResponseWrapper(app)
        pipeline_config = make_wrapper_config()
        await stage.initialize(pipeline_config)

        query = text_query("hello")
        query.pipeline_config = pipeline_config
        query.resp_message_chain = []
        query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']

        # Create assistant response
        assistant_resp = Mock()
        assistant_resp.role = 'assistant'
        assistant_resp.content = "Hello"
        assistant_resp.tool_calls = None
        assistant_resp.get_content_platform_message_chain = Mock(
            return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
        )
        query.resp_messages = [assistant_resp]

        # Only the emit_event call matters here; drain the stage output.
        async for _ in stage.process(query, 'ResponseWrapper'):
            pass

        emit_event = app.plugin_connector.emit_event
        # call_args is None when the mock was never invoked; assert first so a
        # missing call fails with a clear message instead of a TypeError.
        emit_event.assert_awaited()
        # Second positional argument is bound_plugins
        assert emit_event.call_args.args[1] == ['plugin1', 'plugin2']