Files
LangBot/tests/unit_tests/pipeline/test_msgtrun.py
huanghuoguoguo 70ec75f9a2 feat(test): Phase 1.5 coverage expansion - COV-001 to COV-013
Coverage baseline raised from 13.65% to 26% (+12.35%)
Gate raised from 12% to 18%

Tasks completed:
- COV-001: Command system unit tests (100% coverage)
- COV-002: API service unit tests batch 1 (user/apikey/model/provider)
- COV-003: Provider model manager unit tests
- COV-004: Pipeline remaining stage tests (aggregator/cntfilter/longtext/msgtrun)
- COV-005: Storage and utils coverage pass
- COV-006: Gate ratchet 12%→15%
- COV-007: Gate ratchet 15%→18%
- COV-008: API service batch 2 (bot/pipeline/webhook/space/maintenance/mcp)
- COV-009: Blocked - API controller circular import issue documented
- COV-010: Plugin runtime unit tests (+0.08%)
- COV-011: RAG and vector unit tests (+0.68%)
- COV-012: Core boot and migration unit tests
- COV-013: Provider requester logic unit tests (+0.62%)

Key additions:
- tests/utils/import_isolation.py: sys.modules isolation for circular imports
- Provider requester mock tests: proved HTTP-dependent code can be tested locally
- Vector filter utilities: 100% coverage on pure functions
- API services: fake persistence pattern for unit testing

Blocked issue COV-009 documented in langbot-test-plan/1.5/issues/

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-16 10:12:48 +08:00

307 lines
10 KiB
Python

"""
Unit tests for ConversationMessageTruncator (msgtrun) pipeline stage.
Tests cover:
- Normal truncation behavior based on max-round
- Boundary length handling
- Empty message handling
- Multi-message chain truncation
"""
from __future__ import annotations
import pytest
from importlib import import_module
from tests.factories import (
FakeApp,
text_query,
)
import langbot_plugin.api.entities.builtin.provider.message as provider_message
def get_msgtrun_module():
    """Return the msgtrun stage module, imported lazily at call time.

    The pipeline manager module is imported before the stage module so that
    stage registration happens first; deferring both imports to call time
    avoids circular-import issues during test collection.
    """
    import_module('langbot.pkg.pipeline.pipelinemgr')
    msgtrun_module = import_module('langbot.pkg.pipeline.msgtrun.msgtrun')
    return msgtrun_module
def get_truncator_module():
    """Return the truncator base module (``msgtrun.truncator``), imported on demand."""
    module_path = 'langbot.pkg.pipeline.msgtrun.truncator'
    return import_module(module_path)
def get_entities_module():
    """Return the pipeline entities module, imported on demand."""
    module_path = 'langbot.pkg.pipeline.entities'
    return import_module(module_path)
def get_round_truncator_module():
    """Return the round-based truncator module (``msgtrun.truncators.round``), imported on demand."""
    module_path = 'langbot.pkg.pipeline.msgtrun.truncators.round'
    return import_module(module_path)
def make_truncate_config(max_round: int = 5):
    """Build a minimal pipeline config holding only ``ai.local-agent.max-round``.

    A fresh dict is returned on every call, so callers may mutate it freely.
    """
    local_agent_cfg = {'max-round': max_round}
    ai_cfg = {'local-agent': local_agent_cfg}
    return {'ai': ai_cfg}
class TestConversationMessageTruncatorInit:
    """Tests for ConversationMessageTruncator initialization."""

    @pytest.mark.asyncio
    async def test_initialize_round_truncator(self):
        """Initialize should select the 'round' truncator by default."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()
        stage = msgtrun.ConversationMessageTruncator(FakeApp())

        await stage.initialize(make_truncate_config())

        assert isinstance(stage.trun, truncator.Truncator)
        # The isinstance check alone would accept any registered truncator;
        # pin the specific one this test's docstring promises.
        assert stage.trun.name == 'round'

    @pytest.mark.asyncio
    async def test_initialize_unknown_truncator_raises(self, monkeypatch):
        """Initialize with unknown truncator method should raise ValueError."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()
        # Swap in an empty registry so no truncator matches the requested method.
        # monkeypatch restores the original list object (not a copy of it) at
        # teardown, even when the test fails.
        monkeypatch.setattr(truncator, 'preregistered_truncators', [])
        stage = msgtrun.ConversationMessageTruncator(FakeApp())

        with pytest.raises(ValueError, match='Unknown truncator'):
            await stage.initialize(make_truncate_config())
class TestRoundTruncatorProcess:
    """Tests for RoundTruncator truncation behavior, driven through the stage.

    Expected results follow the round model the tests below pin down: walking
    backwards from the newest message, each user message closes one round and
    the current user message is round 1. Everything older than the
    ``max-round``-th user message is dropped.
    """

    @staticmethod
    def _chain(*pairs):
        """Build a provider message chain from (role, content) pairs."""
        return [provider_message.Message(role=role, content=content) for role, content in pairs]

    @staticmethod
    def _pairs(messages):
        """Reduce a message chain to comparable (role, content) pairs."""
        return [(msg.role, msg.content) for msg in messages]

    @staticmethod
    async def _truncate(query_text, messages, max_round=5):
        """Run the stage over *messages* and return the resulting message chain.

        Builds a ConversationMessageTruncator configured with *max_round*, then
        processes a query whose text is *query_text* and whose chain is
        *messages*. It also asserts that the stage tells the pipeline to CONTINUE.
        """
        msgtrun = get_msgtrun_module()
        entities = get_entities_module()
        stage = msgtrun.ConversationMessageTruncator(FakeApp())
        pipeline_config = make_truncate_config(max_round=max_round)
        await stage.initialize(pipeline_config)

        query = text_query(query_text)
        query.pipeline_config = pipeline_config
        query.messages = messages

        result = await stage.process(query, 'ConversationMessageTruncator')
        assert result.result_type == entities.ResultType.CONTINUE
        return result.new_query.messages

    @pytest.mark.asyncio
    async def test_truncate_within_limit(self):
        """Messages within max-round limit should not be truncated."""
        # 5 messages spanning 3 user rounds: fits within max-round=5.
        pairs = [
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'current message'),
        ]
        messages = await self._truncate('current message', self._chain(*pairs), max_round=5)
        assert self._pairs(messages) == pairs

    @pytest.mark.asyncio
    async def test_truncate_exceeds_limit(self):
        """Messages exceeding max-round should be truncated."""
        pairs = [
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'message 3'),
            ('assistant', 'response 3'),
            ('user', 'current message'),
        ]
        messages = await self._truncate('current message', self._chain(*pairs), max_round=2)
        # The current message is round 1 and 'message 3' is round 2, so only
        # 'message 3' and everything after it survive.
        assert self._pairs(messages) == pairs[-3:]

    @pytest.mark.asyncio
    async def test_truncate_empty_messages(self):
        """Empty messages list should return empty list."""
        messages = await self._truncate('hello', [])
        assert messages == []

    @pytest.mark.asyncio
    async def test_truncate_single_message(self):
        """Single message should be preserved."""
        messages = await self._truncate('hello', self._chain(('user', 'hello')))
        assert self._pairs(messages) == [('user', 'hello')]

    @pytest.mark.asyncio
    async def test_truncate_preserves_order(self):
        """Truncation should preserve message order."""
        pairs = [
            ('user', 'user1'),
            ('assistant', 'asst1'),
            ('user', 'user2'),
            ('assistant', 'asst2'),
            ('user', 'user3'),
        ]
        messages = await self._truncate('current', self._chain(*pairs), max_round=2)
        # Unconditional check: the oldest round is dropped and the survivors
        # keep their original order (user2 -> asst2 -> user3).
        assert self._pairs(messages) == pairs[2:]

    @pytest.mark.asyncio
    async def test_truncate_max_round_one(self):
        """max-round=1 should only keep last user message."""
        pairs = [
            ('user', 'old1'),
            ('assistant', 'old1_resp'),
            ('user', 'current'),
        ]
        messages = await self._truncate('current', self._chain(*pairs), max_round=1)
        # The current user message alone fills the single allowed round.
        assert self._pairs(messages) == [('user', 'current')]
class TestRoundTruncatorDirect:
    """Direct tests for RoundTruncator class."""

    @pytest.mark.asyncio
    async def test_round_truncator_direct_process(self):
        """Test RoundTruncator truncate method directly."""
        # Import the stage module first (pipelinemgr before msgtrun, as the other
        # tests do), then import the round module explicitly. Without this the
        # test depends on another test having imported it. This presumes the
        # 'round' class registers itself on import.
        get_msgtrun_module()
        get_round_truncator_module()
        truncator_mod = get_truncator_module()

        # Fail with a clear message instead of an UnboundLocalError when the
        # 'round' truncator is missing from the registry.
        trun_cls = next(
            (cls for cls in truncator_mod.preregistered_truncators if cls.name == 'round'),
            None,
        )
        assert trun_cls is not None, "'round' truncator is not registered"
        trun = trun_cls(FakeApp())

        query = text_query("hello")
        query.pipeline_config = make_truncate_config(max_round=3)
        query.messages = [
            provider_message.Message(role='user', content='m1'),
            provider_message.Message(role='assistant', content='r1'),
            provider_message.Message(role='user', content='m2'),
            provider_message.Message(role='assistant', content='r2'),
            provider_message.Message(role='user', content='hello'),
        ]

        result = await trun.truncate(query)

        # Three user rounds with max-round=3: the whole chain is kept, in order.
        assert [(msg.role, msg.content) for msg in result.messages] == [
            ('user', 'm1'),
            ('assistant', 'r1'),
            ('user', 'm2'),
            ('assistant', 'r2'),
            ('user', 'hello'),
        ]