mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
322 lines
11 KiB
Python
322 lines
11 KiB
Python
"""
|
|
Unit tests for ConversationMessageTruncator (msgtrun) pipeline stage.
|
|
|
|
Tests cover:
|
|
- Normal truncation behavior based on max-round
|
|
- Boundary length handling
|
|
- Empty message handling
|
|
- Multi-message chain truncation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from importlib import import_module
|
|
|
|
from tests.factories import (
|
|
FakeApp,
|
|
text_query,
|
|
)
|
|
|
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
|
|
|
|
def get_msgtrun_module():
|
|
"""Lazy import to avoid circular import issues."""
|
|
# Import pipelinemgr first to trigger stage registration
|
|
import_module('langbot.pkg.pipeline.pipelinemgr')
|
|
return import_module('langbot.pkg.pipeline.msgtrun.msgtrun')
|
|
|
|
|
|
def get_truncator_module():
|
|
"""Lazy import for truncator base."""
|
|
return import_module('langbot.pkg.pipeline.msgtrun.truncator')
|
|
|
|
|
|
def get_entities_module():
|
|
"""Lazy import for pipeline entities."""
|
|
return import_module('langbot.pkg.pipeline.entities')
|
|
|
|
|
|
def get_round_truncator_module():
|
|
"""Lazy import for round truncator."""
|
|
return import_module('langbot.pkg.pipeline.msgtrun.truncators.round')
|
|
|
|
|
|
def make_truncate_config(max_round: int = 5):
|
|
"""Create a pipeline config with max-round setting."""
|
|
return {
|
|
'ai': {
|
|
'local-agent': {
|
|
'max-round': max_round,
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
class TestConversationMessageTruncatorInit:
|
|
"""Tests for ConversationMessageTruncator initialization."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialize_round_truncator(self):
|
|
"""Initialize should select 'round' truncator by default."""
|
|
msgtrun = get_msgtrun_module()
|
|
truncator = get_truncator_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config()
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
assert stage.trun is not None
|
|
assert isinstance(stage.trun, truncator.Truncator)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_initialize_unknown_truncator_raises(self):
|
|
"""Initialize with unknown truncator method should raise ValueError."""
|
|
msgtrun = get_msgtrun_module()
|
|
truncator = get_truncator_module()
|
|
|
|
# Save original preregistered_truncators
|
|
original_truncators = truncator.preregistered_truncators.copy()
|
|
|
|
try:
|
|
# Clear registered truncators to simulate unknown method
|
|
truncator.preregistered_truncators = []
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config()
|
|
|
|
with pytest.raises(ValueError, match='Unknown truncator'):
|
|
await stage.initialize(pipeline_config)
|
|
finally:
|
|
# Restore original truncators
|
|
truncator.preregistered_truncators = original_truncators
|
|
|
|
|
|
class TestRoundTruncatorProcess:
|
|
"""Tests for RoundTruncator truncation behavior."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_within_limit(self):
|
|
"""Messages within max-round limit should not be truncated."""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config(max_round=5)
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
# Create query with 3 messages (within limit)
|
|
query = text_query("current message")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='message 1'),
|
|
provider_message.Message(role='assistant', content='response 1'),
|
|
provider_message.Message(role='user', content='message 2'),
|
|
provider_message.Message(role='assistant', content='response 2'),
|
|
provider_message.Message(role='user', content='current message'),
|
|
]
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
# All messages should be preserved
|
|
assert len(result.new_query.messages) == 5
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_exceeds_limit(self):
|
|
"""Messages exceeding max-round should be truncated precisely.
|
|
|
|
Algorithm: traverse backwards, collect while current_round < max_round, count user messages as rounds.
|
|
For max_round=2 with 7 messages (u1, a1, u2, a2, u3, a3, u_current):
|
|
- Iterate: u_current(r=0<2, collect, r=1), a3(r=1<2, collect), u3(r=1<2, collect, r=2)
|
|
- a2: r=2 not < 2 → break
|
|
- Collected reverse: [u_current, a3, u3]
|
|
- Reversed: [u3, a3, u_current] = 3 messages
|
|
"""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config(max_round=2) # Only keep 2 rounds
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
# Create query with many messages exceeding limit
|
|
# 7 messages = 3 full rounds + 1 current user
|
|
query = text_query("current message")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='message 1'),
|
|
provider_message.Message(role='assistant', content='response 1'),
|
|
provider_message.Message(role='user', content='message 2'),
|
|
provider_message.Message(role='assistant', content='response 2'),
|
|
provider_message.Message(role='user', content='message 3'),
|
|
provider_message.Message(role='assistant', content='response 3'),
|
|
provider_message.Message(role='user', content='current message'),
|
|
]
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
# Should keep exactly 3 messages: message3, response3, current message
|
|
messages = result.new_query.messages
|
|
assert len(messages) == 3
|
|
|
|
# Verify exact message content
|
|
assert messages[0].role == 'user'
|
|
assert messages[0].content == 'message 3'
|
|
assert messages[1].role == 'assistant'
|
|
assert messages[1].content == 'response 3'
|
|
assert messages[2].role == 'user'
|
|
assert messages[2].content == 'current message'
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_empty_messages(self):
|
|
"""Empty messages list should return empty list."""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config()
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
query = text_query("hello")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = []
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
assert len(result.new_query.messages) == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_single_message(self):
|
|
"""Single message should be preserved."""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config()
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
query = text_query("hello")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='hello'),
|
|
]
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
assert len(result.new_query.messages) == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_preserves_order(self):
|
|
"""Truncation should preserve message order."""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config(max_round=2)
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
query = text_query("current")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='user1'),
|
|
provider_message.Message(role='assistant', content='asst1'),
|
|
provider_message.Message(role='user', content='user2'),
|
|
provider_message.Message(role='assistant', content='asst2'),
|
|
provider_message.Message(role='user', content='user3'),
|
|
]
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
|
|
messages = result.new_query.messages
|
|
assert [(msg.role, msg.content) for msg in messages] == [
|
|
('user', 'user2'),
|
|
('assistant', 'asst2'),
|
|
('user', 'user3'),
|
|
]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_truncate_max_round_one(self):
|
|
"""max-round=1 should only keep last user message."""
|
|
msgtrun = get_msgtrun_module()
|
|
entities = get_entities_module()
|
|
|
|
app = FakeApp()
|
|
stage = msgtrun.ConversationMessageTruncator(app)
|
|
|
|
pipeline_config = make_truncate_config(max_round=1)
|
|
|
|
await stage.initialize(pipeline_config)
|
|
|
|
query = text_query("current")
|
|
query.pipeline_config = pipeline_config
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='old1'),
|
|
provider_message.Message(role='assistant', content='old1_resp'),
|
|
provider_message.Message(role='user', content='current'),
|
|
]
|
|
|
|
result = await stage.process(query, 'ConversationMessageTruncator')
|
|
|
|
assert result.result_type == entities.ResultType.CONTINUE
|
|
messages = result.new_query.messages
|
|
assert [(msg.role, msg.content) for msg in messages] == [('user', 'current')]
|
|
|
|
|
|
class TestRoundTruncatorDirect:
|
|
"""Direct tests for RoundTruncator class."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_truncator_direct_process(self):
|
|
"""Test RoundTruncator truncate method directly."""
|
|
truncator_mod = get_truncator_module()
|
|
|
|
app = FakeApp()
|
|
|
|
# Get the RoundTruncator class from preregistered
|
|
for trun_cls in truncator_mod.preregistered_truncators:
|
|
if trun_cls.name == 'round':
|
|
trun = trun_cls(app)
|
|
break
|
|
|
|
query = text_query("hello")
|
|
query.pipeline_config = make_truncate_config(max_round=3)
|
|
query.messages = [
|
|
provider_message.Message(role='user', content='m1'),
|
|
provider_message.Message(role='assistant', content='r1'),
|
|
provider_message.Message(role='user', content='m2'),
|
|
provider_message.Message(role='assistant', content='r2'),
|
|
provider_message.Message(role='user', content='hello'),
|
|
]
|
|
|
|
result = await trun.truncate(query)
|
|
|
|
assert result is not None
|
|
assert hasattr(result, 'messages')
|