Files
LangBot/tests/unit_tests/pipeline/test_msgtrun.py
huanghuoguoguo 1a3c73bc05 test(quality): fix fake tests and add missing coverage
P0 fixes:
- telemetry: rewrite fake tests with real behavior verification (25 tests)
- config: delete copied-source tests, use proper imports (2 deleted)
- persistence: fix try-except pass to verify specific errors

P1 fixes:
- pipeline: add real FixedWindowAlgo tests instead of mocks (12 tests)
- provider: add SessionManager and ToolManager tests (25 tests)
- storage: add S3StorageProvider tests with moto mock (16 tests)
- plugin: add handler action tests for setting inheritance (15 tests)
- rag: add file storage and ZIP processing tests (21 tests)
- vector: add VDB filter conversion tests (30 tests)

P2 fixes:
- pipeline/msgtrun: strengthen assertions for exact message count
- api: add response structure validation in integration tests

New test files:
- provider/test_session_manager.py
- provider/test_tool_manager.py
- storage/test_s3storage.py
- plugin/test_handler_actions.py
- rag/test_file_storage.py
- vector/test_vdb_filter_conversion.py

Source code bugs documented:
- provider: TokenManager.next_token() ZeroDivisionError
- telemetry: send_tasks class variable shared state
- command: empty command IndexError, unused parameters
- utils: funcschema KeyError
- entity: vector.py independent declarative_base

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-16 10:13:15 +08:00

324 lines
11 KiB
Python

"""
Unit tests for ConversationMessageTruncator (msgtrun) pipeline stage.
Tests cover:
- Normal truncation behavior based on max-round
- Boundary length handling
- Empty message handling
- Multi-message chain truncation
"""
from __future__ import annotations
import pytest
from importlib import import_module
from tests.factories import (
FakeApp,
text_query,
)
import langbot_plugin.api.entities.builtin.provider.message as provider_message
def get_msgtrun_module():
    """Import and return the msgtrun stage module.

    The import happens at call time rather than at module load to avoid
    circular import issues.

    Returns:
        The imported ``langbot.pkg.pipeline.msgtrun.msgtrun`` module.
    """
    # pipelinemgr has to be imported first: loading it triggers stage registration.
    import_module('langbot.pkg.pipeline.pipelinemgr')
    return import_module('langbot.pkg.pipeline.msgtrun.msgtrun')
def get_truncator_module():
    """Import and return the truncator base module at call time.

    Returns:
        The imported ``langbot.pkg.pipeline.msgtrun.truncator`` module.
    """
    return import_module('langbot.pkg.pipeline.msgtrun.truncator')
def get_entities_module():
    """Import and return the pipeline entities module at call time.

    Returns:
        The imported ``langbot.pkg.pipeline.entities`` module.
    """
    return import_module('langbot.pkg.pipeline.entities')
def get_round_truncator_module():
    """Import and return the round truncator module at call time.

    Returns:
        The imported ``langbot.pkg.pipeline.msgtrun.truncators.round`` module.
    """
    return import_module('langbot.pkg.pipeline.msgtrun.truncators.round')
def make_truncate_config(max_round: int = 5) -> dict:
    """Build a minimal pipeline config that carries only the max-round setting.

    Args:
        max_round: Value stored under ``config['ai']['local-agent']['max-round']``.

    Returns:
        A newly created nested dict on every call, so callers may mutate it
        without affecting other tests.
    """
    return {
        'ai': {
            'local-agent': {
                'max-round': max_round,
            },
        },
    }
class TestConversationMessageTruncatorInit:
    """Tests for ConversationMessageTruncator initialization."""

    @pytest.mark.asyncio
    async def test_initialize_round_truncator(self):
        """Initialize should select the 'round' truncator by default."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()

        stage = msgtrun.ConversationMessageTruncator(FakeApp())
        await stage.initialize(make_truncate_config())

        assert stage.trun is not None
        assert isinstance(stage.trun, truncator.Truncator)
        # Pin WHICH truncator was chosen, not merely that one was chosen.
        assert type(stage.trun).name == 'round'

    @pytest.mark.asyncio
    async def test_initialize_unknown_truncator_raises(self):
        """Initialize should raise ValueError when no truncator matches."""
        msgtrun = get_msgtrun_module()
        truncator = get_truncator_module()

        # Hold on to the registry object itself. The test only rebinds the
        # module attribute and never mutates the list, so no copy is needed,
        # and restoring the very same object keeps its identity intact.
        registered = truncator.preregistered_truncators
        # An empty registry simulates an unknown truncator method.
        truncator.preregistered_truncators = []
        try:
            stage = msgtrun.ConversationMessageTruncator(FakeApp())
            with pytest.raises(ValueError, match='Unknown truncator'):
                await stage.initialize(make_truncate_config())
        finally:
            # Always put the registry back so later tests are unaffected.
            truncator.preregistered_truncators = registered
class TestRoundTruncatorProcess:
    """Tests for RoundTruncator truncation behavior.

    The expected results follow the round algorithm: walk the messages from
    newest to oldest, keep collecting while ``current_round < max_round``, and
    count every collected user message as one round.
    """

    @staticmethod
    def _chain(*turns):
        """Build a message list from ``(role, content)`` pairs."""
        return [provider_message.Message(role=role, content=content) for role, content in turns]

    @staticmethod
    def _pairs(messages):
        """Flatten messages to ``(role, content)`` tuples for exact comparison."""
        return [(message.role, message.content) for message in messages]

    @staticmethod
    async def _run_stage(messages, *, max_round=5, text='hello'):
        """Initialize the stage, process a query holding *messages*, return the result.

        Args:
            messages: Message chain assigned to ``query.messages``.
            max_round: Value for the ``max-round`` pipeline setting.
            text: Text of the incoming query built via ``text_query``.
        """
        msgtrun = get_msgtrun_module()
        pipeline_config = make_truncate_config(max_round=max_round)

        stage = msgtrun.ConversationMessageTruncator(FakeApp())
        await stage.initialize(pipeline_config)

        query = text_query(text)
        query.pipeline_config = pipeline_config
        query.messages = messages
        return await stage.process(query, 'ConversationMessageTruncator')

    @pytest.mark.asyncio
    async def test_truncate_within_limit(self):
        """Messages within max-round limit should not be truncated."""
        entities = get_entities_module()
        # 5 messages containing 3 user rounds, which is below max_round=5.
        turns = [
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'current message'),
        ]

        result = await self._run_stage(self._chain(*turns), max_round=5, text="current message")

        assert result.result_type == entities.ResultType.CONTINUE
        # All messages should be preserved, in their original order.
        assert self._pairs(result.new_query.messages) == turns

    @pytest.mark.asyncio
    async def test_truncate_exceeds_limit(self):
        """Messages exceeding max-round should be truncated precisely.

        For max_round=2 with 7 messages (u1, a1, u2, a2, u3, a3, u_current):
        - u_current: r=0 < 2, collect, r=1
        - a3:        r=1 < 2, collect
        - u3:        r=1 < 2, collect, r=2
        - a2:        r=2 is not < 2, stop
        Collected (newest first) [u_current, a3, u3], reversed to
        [u3, a3, u_current] = 3 messages.
        """
        entities = get_entities_module()
        # 7 messages = 3 full rounds + 1 current user message.
        chain = self._chain(
            ('user', 'message 1'),
            ('assistant', 'response 1'),
            ('user', 'message 2'),
            ('assistant', 'response 2'),
            ('user', 'message 3'),
            ('assistant', 'response 3'),
            ('user', 'current message'),
        )

        result = await self._run_stage(chain, max_round=2, text="current message")

        assert result.result_type == entities.ResultType.CONTINUE
        # Exactly message 3, response 3 and the current message must remain.
        assert self._pairs(result.new_query.messages) == [
            ('user', 'message 3'),
            ('assistant', 'response 3'),
            ('user', 'current message'),
        ]

    @pytest.mark.asyncio
    async def test_truncate_empty_messages(self):
        """Empty messages list should return empty list."""
        entities = get_entities_module()

        result = await self._run_stage([], text="hello")

        assert result.result_type == entities.ResultType.CONTINUE
        assert len(result.new_query.messages) == 0

    @pytest.mark.asyncio
    async def test_truncate_single_message(self):
        """Single message should be preserved."""
        entities = get_entities_module()

        result = await self._run_stage(self._chain(('user', 'hello')), text="hello")

        assert result.result_type == entities.ResultType.CONTINUE
        assert self._pairs(result.new_query.messages) == [('user', 'hello')]

    @pytest.mark.asyncio
    async def test_truncate_preserves_order(self):
        """Truncation should preserve message order.

        For max_round=2 with (user1, asst1, user2, asst2, user3) the backward
        walk collects user3 (r=1), asst2, user2 (r=2) and stops at asst1.
        """
        entities = get_entities_module()
        chain = self._chain(
            ('user', 'user1'),
            ('assistant', 'asst1'),
            ('user', 'user2'),
            ('assistant', 'asst2'),
            ('user', 'user3'),
        )

        result = await self._run_stage(chain, max_round=2, text="current")

        assert result.result_type == entities.ResultType.CONTINUE
        # Unconditional, exact check: a guarded assertion would pass vacuously
        # if truncation returned fewer messages than expected.
        assert self._pairs(result.new_query.messages) == [
            ('user', 'user2'),
            ('assistant', 'asst2'),
            ('user', 'user3'),
        ]

    @pytest.mark.asyncio
    async def test_truncate_max_round_one(self):
        """max-round=1 should only keep last user message.

        The newest user message is collected and completes round 1, so the
        walk stops before the preceding assistant message.
        """
        entities = get_entities_module()
        chain = self._chain(
            ('user', 'old1'),
            ('assistant', 'old1_resp'),
            ('user', 'current'),
        )

        result = await self._run_stage(chain, max_round=1, text="current")

        assert result.result_type == entities.ResultType.CONTINUE
        assert self._pairs(result.new_query.messages) == [('user', 'current')]
class TestRoundTruncatorDirect:
    """Direct tests for RoundTruncator class."""

    @pytest.mark.asyncio
    async def test_round_truncator_direct_process(self):
        """Test RoundTruncator truncate method directly."""
        # Load the stage module first so this test does not depend on another
        # test having imported it. The initialization tests rely on the same
        # import to make the 'round' truncator available in the registry.
        get_msgtrun_module()
        truncator_mod = get_truncator_module()

        # Look up the RoundTruncator class among the preregistered truncators.
        round_cls = next(
            (cls for cls in truncator_mod.preregistered_truncators if cls.name == 'round'),
            None,
        )
        # Fail with a clear message instead of an unbound-variable error.
        assert round_cls is not None, "'round' truncator is not registered"
        trun = round_cls(FakeApp())

        turns = [
            ('user', 'm1'),
            ('assistant', 'r1'),
            ('user', 'm2'),
            ('assistant', 'r2'),
            ('user', 'hello'),
        ]
        query = text_query("hello")
        query.pipeline_config = make_truncate_config(max_round=3)
        query.messages = [
            provider_message.Message(role=role, content=content) for role, content in turns
        ]

        result = await trun.truncate(query)

        assert result is not None
        assert hasattr(result, 'messages')
        # Three user rounds with max_round=3: the backward walk reaches the
        # oldest message before the limit stops it, so nothing is dropped.
        assert [(message.role, message.content) for message in result.messages] == turns