mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 08:16:03 +00:00
refactor(agent-runner): remove host context windowing
This commit is contained in:
@@ -120,7 +120,7 @@ class TestResolveRunnerConfig:
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'model': 'uuid-123',
|
||||
'max_round': 10,
|
||||
'custom_option': 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -130,7 +130,7 @@ class TestResolveRunnerConfig:
|
||||
pipeline_config,
|
||||
'plugin:langbot/local-agent/default',
|
||||
)
|
||||
assert config == {'model': 'uuid-123', 'max_round': 10}
|
||||
assert config == {'model': 'uuid-123', 'custom_option': 10}
|
||||
|
||||
def test_resolve_old_format_config(self):
|
||||
"""Runtime config resolver should not read old format."""
|
||||
@@ -138,7 +138,7 @@ class TestResolveRunnerConfig:
|
||||
'ai': {
|
||||
'local-agent': {
|
||||
'model': 'uuid-123',
|
||||
'max_round': 10,
|
||||
'custom_option': 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -155,7 +155,7 @@ class TestResolveRunnerConfig:
|
||||
'ai': {
|
||||
'local-agent': {
|
||||
'model': 'uuid-123',
|
||||
'max_round': 10,
|
||||
'custom_option': 10,
|
||||
'knowledge-base': 'kb-123',
|
||||
},
|
||||
},
|
||||
@@ -165,7 +165,7 @@ class TestResolveRunnerConfig:
|
||||
pipeline_config,
|
||||
'plugin:langbot/local-agent/default',
|
||||
)
|
||||
assert config == {'model': 'uuid-123', 'max_round': 10, 'knowledge-bases': ['kb-123']}
|
||||
assert config == {'model': 'uuid-123', 'custom_option': 10, 'knowledge-bases': ['kb-123']}
|
||||
assert 'knowledge-base' not in config
|
||||
|
||||
def test_resolve_no_config(self):
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestMigratePipelineConfig:
|
||||
},
|
||||
'local-agent': {
|
||||
'model': {'primary': 'model-uuid', 'fallbacks': []},
|
||||
'max-round': 10,
|
||||
'knowledge-base': 'kb-uuid',
|
||||
'prompt': [{'role': 'system', 'content': 'Hello'}],
|
||||
},
|
||||
@@ -35,9 +34,9 @@ class TestMigratePipelineConfig:
|
||||
|
||||
# Config should be in runner_config
|
||||
assert 'plugin:langbot/local-agent/default' in migrated['ai']['runner_config']
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default']['max-round'] == 10
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default']['knowledge-bases'] == ['kb-uuid']
|
||||
assert 'knowledge-base' not in migrated['ai']['runner_config']['plugin:langbot/local-agent/default']
|
||||
assert 'max-round' not in migrated['ai']['runner_config']['plugin:langbot/local-agent/default']
|
||||
|
||||
# Expire-time preserved
|
||||
assert migrated['ai']['runner']['expire-time'] == 0
|
||||
@@ -76,7 +75,7 @@ class TestMigratePipelineConfig:
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'model': {'primary': '', 'fallbacks': []},
|
||||
'max-round': 10,
|
||||
'custom-option': 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -86,7 +85,7 @@ class TestMigratePipelineConfig:
|
||||
|
||||
# Should remain unchanged
|
||||
assert migrated['ai']['runner']['id'] == 'plugin:langbot/local-agent/default'
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default']['max-round'] == 10
|
||||
assert migrated['ai']['runner_config']['plugin:langbot/local-agent/default']['custom-option'] == 10
|
||||
|
||||
def test_new_format_local_agent_config_normalizes_legacy_kb_key(self):
|
||||
"""Migration should normalize legacy KB aliases before runtime."""
|
||||
@@ -260,18 +259,18 @@ class TestResolveRunnerConfig:
|
||||
config = {
|
||||
'ai': {
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {'max-round': 20},
|
||||
'plugin:langbot/local-agent/default': {'custom-option': 20},
|
||||
},
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
assert runner_config['max-round'] == 20
|
||||
assert runner_config['custom-option'] == 20
|
||||
|
||||
def test_resolve_old_format_config(self):
|
||||
"""resolve_runner_config should not read old ai.local-agent at runtime."""
|
||||
config = {
|
||||
'ai': {
|
||||
'local-agent': {'max-round': 15},
|
||||
'local-agent': {'max-round': 15, 'custom-option': 20},
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
@@ -281,21 +280,21 @@ class TestResolveRunnerConfig:
|
||||
"""resolve_legacy_runner_config should read old ai.local-agent for migration."""
|
||||
config = {
|
||||
'ai': {
|
||||
'local-agent': {'max-round': 15},
|
||||
'local-agent': {'max-round': 15, 'custom-option': 20},
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_legacy_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
assert runner_config['max-round'] == 15
|
||||
assert runner_config == {'custom-option': 20}
|
||||
|
||||
def test_resolve_new_format_priority(self):
|
||||
"""New format runner_config should take priority."""
|
||||
config = {
|
||||
'ai': {
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {'max-round': 25},
|
||||
'plugin:langbot/local-agent/default': {'custom-option': 25},
|
||||
},
|
||||
'local-agent': {'max-round': 10}, # Old, should be ignored
|
||||
'local-agent': {'max-round': 10, 'custom-option': 10}, # Old, should be ignored
|
||||
},
|
||||
}
|
||||
runner_config = ConfigMigration.resolve_runner_config(config, 'plugin:langbot/local-agent/default')
|
||||
assert runner_config['max-round'] == 25
|
||||
assert runner_config['custom-option'] == 25
|
||||
|
||||
@@ -4,7 +4,7 @@ Tests cover:
|
||||
1. Pipeline Query -> AgentEventEnvelope conversion
|
||||
2. Pipeline config -> AgentBinding conversion
|
||||
3. AgentRunContext not inlining full history by default
|
||||
4. Pipeline max-round only affecting bootstrap/adapter context
|
||||
4. LangBot Host not defining context-window controls
|
||||
5. Event-first run() entry point
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -147,23 +147,13 @@ class TestPipelineConfigToBinding:
|
||||
assert binding.scope.scope_type == "pipeline"
|
||||
assert binding.scope.scope_id == mock_query.pipeline_uuid
|
||||
|
||||
def test_config_to_binding_max_round(self, mock_query_with_max_round):
|
||||
"""Test max_round extraction for Pipeline adapter."""
|
||||
binding = PipelineAdapter.pipeline_config_to_binding(
|
||||
mock_query_with_max_round, "plugin:test/plugin/runner"
|
||||
)
|
||||
|
||||
# max_round should be captured but NOT in Protocol v1 entities
|
||||
assert binding.max_round == 10
|
||||
|
||||
def test_config_to_binding_no_max_round(self, mock_query):
|
||||
"""Test binding without max_round."""
|
||||
def test_config_to_binding_does_not_add_host_context_window(self, mock_query):
|
||||
"""Pipeline binding should not define Host-side context window controls."""
|
||||
binding = PipelineAdapter.pipeline_config_to_binding(
|
||||
mock_query, "plugin:test/plugin/runner"
|
||||
)
|
||||
|
||||
# max_round may be None
|
||||
assert binding.max_round is None
|
||||
assert not hasattr(binding, "max_round")
|
||||
|
||||
|
||||
class TestAgentRunContextProtocolV1:
|
||||
@@ -248,60 +238,23 @@ class TestAgentRunContextProtocolV1:
|
||||
assert ctx.bootstrap is None or isinstance(ctx.bootstrap.messages, list)
|
||||
|
||||
|
||||
class TestMaxRoundNotInProtocol:
|
||||
"""Test that Pipeline max-round only affects adapter context, not Protocol v1."""
|
||||
class TestHostContextWindowNotInProtocol:
|
||||
"""Test that Host-side context window controls are not in Protocol v1."""
|
||||
|
||||
def test_max_round_not_in_sdk_context(self):
|
||||
"""Test max-round is not a field in SDK AgentRunContext."""
|
||||
# AgentRunContext should not have max_round field
|
||||
def test_context_window_not_in_sdk_context(self):
|
||||
"""AgentRunContext should not expose Host-side window controls."""
|
||||
ctx_fields = AgentRunContext.model_fields.keys()
|
||||
|
||||
assert "max_round" not in ctx_fields
|
||||
assert "maxRound" not in ctx_fields
|
||||
|
||||
def test_max_round_in_adapter_context(self):
|
||||
"""Test max_round is in adapter context, not main context."""
|
||||
trigger = AgentTrigger(type="message.received")
|
||||
event = AgentEventContext(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
)
|
||||
input = AgentInput(text="Hello")
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.resources import AgentResources
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.runtime import AgentRuntimeContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.delivery import DeliveryContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.context import AdapterContext
|
||||
|
||||
adapter = AdapterContext(max_round=10)
|
||||
|
||||
ctx = AgentRunContext(
|
||||
run_id="run_1",
|
||||
trigger=trigger,
|
||||
event=event,
|
||||
input=input,
|
||||
delivery=DeliveryContext(surface="platform"),
|
||||
resources=AgentResources(),
|
||||
runtime=AgentRuntimeContext(),
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
# max_round is in adapter context, not main context
|
||||
assert ctx.adapter is not None
|
||||
assert ctx.adapter.max_round == 10
|
||||
|
||||
def test_binding_max_round_for_adapter_only(self, mock_query_with_max_round):
|
||||
"""Test max_round in binding is for adapter use, not Protocol v1."""
|
||||
def test_binding_has_no_context_window_field(self, mock_query):
|
||||
"""Pipeline adapter should not attach context window policy to binding."""
|
||||
binding = PipelineAdapter.pipeline_config_to_binding(
|
||||
mock_query_with_max_round, "plugin:test/plugin/runner"
|
||||
mock_query, "plugin:test/plugin/runner"
|
||||
)
|
||||
|
||||
# max_round is in binding (Host-internal) for Pipeline adapter
|
||||
assert binding.max_round == 10
|
||||
|
||||
# But SDK entities don't have it
|
||||
ctx_fields = AgentRunContext.model_fields.keys()
|
||||
assert "max_round" not in ctx_fields
|
||||
assert not hasattr(binding, "max_round")
|
||||
|
||||
|
||||
class TestSDKCapabilitiesProtocolV1:
|
||||
@@ -416,18 +369,6 @@ def mock_query():
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_with_max_round(mock_query):
|
||||
"""Create a mock Query with max_round configuration."""
|
||||
mock_query.pipeline_config = {
|
||||
"ai": {
|
||||
"runner": "plugin:test/plugin/runner",
|
||||
"max-round": 10,
|
||||
}
|
||||
}
|
||||
return mock_query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_no_session():
|
||||
"""Create a mock Query without session."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import datetime
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
@@ -332,8 +332,8 @@ async def test_orchestrator_runs_fake_plugin_with_authorized_context(clean_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_packages_max_round_without_mutating_query(clean_agent_state):
|
||||
"""Test that max-round is packaged without mutating original query."""
|
||||
async def test_orchestrator_does_not_package_query_messages_into_context(clean_agent_state):
|
||||
"""Host should not build an agent working-context window from query.messages."""
|
||||
db_engine = clean_agent_state
|
||||
descriptor = make_descriptor()
|
||||
plugin_connector = FakePluginConnector(
|
||||
@@ -347,7 +347,7 @@ async def test_orchestrator_packages_max_round_without_mutating_query(clean_agen
|
||||
ap = FakeApplication(plugin_connector, db_engine)
|
||||
orchestrator = AgentRunOrchestrator(ap, FakeRegistry(descriptor))
|
||||
query = make_query()
|
||||
query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["max-round"] = 2
|
||||
query.pipeline_config["ai"]["runner_config"][RUNNER_ID]["agent-window"] = 2
|
||||
query.messages = [
|
||||
provider_message.Message(role="user", content="message 1"),
|
||||
provider_message.Message(role="assistant", content="response 1"),
|
||||
@@ -361,21 +361,10 @@ async def test_orchestrator_packages_max_round_without_mutating_query(clean_agen
|
||||
|
||||
assert len(messages) == 1
|
||||
context = plugin_connector.contexts[0]
|
||||
# Protocol v1: messages are in bootstrap.messages
|
||||
assert context["bootstrap"] is not None
|
||||
assert [message["content"] for message in context["bootstrap"]["messages"]] == [
|
||||
"message 2",
|
||||
"response 2",
|
||||
"message 3",
|
||||
"response 3",
|
||||
]
|
||||
# Also exposed in adapter.adapter_messages for runners that consume adapter bootstrap.
|
||||
assert [message["content"] for message in context["adapter"]["adapter_messages"]] == [
|
||||
"message 2",
|
||||
"response 2",
|
||||
"message 3",
|
||||
"response 3",
|
||||
]
|
||||
assert context["config"]["agent-window"] == 2
|
||||
assert context["bootstrap"] is None
|
||||
assert "adapter_messages" not in context["adapter"]
|
||||
assert "context_packaging" not in context["runtime"]["metadata"]
|
||||
assert [message.content for message in query.messages] == [
|
||||
"message 1",
|
||||
"response 1",
|
||||
@@ -384,18 +373,6 @@ async def test_orchestrator_packages_max_round_without_mutating_query(clean_agen
|
||||
"message 3",
|
||||
"response 3",
|
||||
]
|
||||
assert context["runtime"]["metadata"]["context_packaging"] == {
|
||||
"policy": {
|
||||
"mode": "max_round",
|
||||
"max_round": 2,
|
||||
},
|
||||
"history": {
|
||||
"source": "query.messages",
|
||||
"source_total_count": 6,
|
||||
"delivered_count": 4,
|
||||
"messages_complete": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -493,7 +470,7 @@ async def test_orchestrator_enforces_total_runner_deadline(clean_agent_state):
|
||||
|
||||
assert exc_info.value.retryable is True
|
||||
assert "runner.timeout" in str(exc_info.value)
|
||||
assert await get_session_registry().get(plugin_connector.contexts[0]["run_id"]) is None
|
||||
assert await get_session_registry().list_active_runs() == []
|
||||
|
||||
|
||||
class TestPipelineCompatibilityQueryIdInSession:
|
||||
@@ -610,7 +587,7 @@ class TestPipelineAdapterPromptAndParams:
|
||||
],
|
||||
)
|
||||
|
||||
messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
_messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
|
||||
context = plugin_connector.contexts[0]
|
||||
# Prompt should be in adapter.extra
|
||||
@@ -641,7 +618,7 @@ class TestPipelineAdapterPromptAndParams:
|
||||
"another_param": 123,
|
||||
}
|
||||
|
||||
messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
_messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
|
||||
context = plugin_connector.contexts[0]
|
||||
assert context["adapter"]["extra"]["params"] == {
|
||||
@@ -671,7 +648,7 @@ class TestPipelineAdapterPromptAndParams:
|
||||
"_pipeline_bound_plugins": ["plugin1"],
|
||||
}
|
||||
|
||||
messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
_messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
|
||||
context = plugin_connector.contexts[0]
|
||||
params = context["adapter"]["extra"]["params"]
|
||||
@@ -703,7 +680,7 @@ class TestPipelineAdapterPromptAndParams:
|
||||
"credential": "secret000",
|
||||
}
|
||||
|
||||
messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
_messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
|
||||
context = plugin_connector.contexts[0]
|
||||
params = context["adapter"]["extra"]["params"]
|
||||
@@ -735,7 +712,7 @@ class TestPipelineAdapterPromptAndParams:
|
||||
"a_lambda": lambda x: x, # function is not JSON-serializable
|
||||
}
|
||||
|
||||
messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
_messages = [message async for message in orchestrator.run_from_query(query)]
|
||||
|
||||
context = plugin_connector.contexts[0]
|
||||
params = context["adapter"]["extra"]["params"]
|
||||
|
||||
@@ -1,369 +0,0 @@
|
||||
"""
|
||||
Unit tests for ConversationMessageTruncator (msgtrun) pipeline stage.
|
||||
|
||||
Tests cover:
|
||||
- Normal truncation behavior based on max-round
|
||||
- Boundary length handling
|
||||
- Empty message handling
|
||||
- Multi-message chain truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
from tests.factories import (
|
||||
FakeApp,
|
||||
text_query,
|
||||
)
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
RUNNER_ID = 'plugin:langbot/local-agent/default'
|
||||
|
||||
|
||||
def get_msgtrun_module():
|
||||
"""Lazy import to avoid circular import issues."""
|
||||
# Import pipelinemgr first to trigger stage registration
|
||||
import_module('langbot.pkg.pipeline.pipelinemgr')
|
||||
return import_module('langbot.pkg.pipeline.msgtrun.msgtrun')
|
||||
|
||||
|
||||
def get_truncator_module():
|
||||
"""Lazy import for truncator base."""
|
||||
return import_module('langbot.pkg.pipeline.msgtrun.truncator')
|
||||
|
||||
|
||||
def get_entities_module():
|
||||
"""Lazy import for pipeline entities."""
|
||||
return import_module('langbot.pkg.pipeline.entities')
|
||||
|
||||
|
||||
def get_round_truncator_module():
|
||||
"""Lazy import for round truncator."""
|
||||
return import_module('langbot.pkg.pipeline.msgtrun.truncators.round')
|
||||
|
||||
|
||||
def make_truncate_config(max_round: int = 5):
|
||||
"""Create a pipeline config with max-round setting."""
|
||||
return {
|
||||
'msg-truncate': {
|
||||
'method': 'round',
|
||||
'round': {
|
||||
'max-round': max_round,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def make_agent_runner_config(max_round: int = 5):
|
||||
"""Create an AgentRunner pipeline config with max-round binding config."""
|
||||
return {
|
||||
'ai': {
|
||||
'runner': {'id': RUNNER_ID},
|
||||
'runner_config': {
|
||||
RUNNER_ID: {
|
||||
'max-round': max_round,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestConversationMessageTruncatorInit:
|
||||
"""Tests for ConversationMessageTruncator initialization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_round_truncator(self):
|
||||
"""Initialize should select 'round' truncator by default."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
truncator = get_truncator_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config()
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
assert stage.trun is not None
|
||||
assert isinstance(stage.trun, truncator.Truncator)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_unknown_truncator_raises(self):
|
||||
"""Initialize with unknown truncator method should raise ValueError."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
truncator = get_truncator_module()
|
||||
|
||||
# Save original preregistered_truncators
|
||||
original_truncators = truncator.preregistered_truncators.copy()
|
||||
|
||||
try:
|
||||
# Clear registered truncators to simulate unknown method
|
||||
truncator.preregistered_truncators = []
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config()
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown truncator'):
|
||||
await stage.initialize(pipeline_config)
|
||||
finally:
|
||||
# Restore original truncators
|
||||
truncator.preregistered_truncators = original_truncators
|
||||
|
||||
|
||||
class TestRoundTruncatorProcess:
|
||||
"""Tests for RoundTruncator truncation behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_within_limit(self):
|
||||
"""Messages within max-round limit should not be truncated."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config(max_round=5)
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
# Create query with 3 messages (within limit)
|
||||
query = text_query("current message")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='message 1'),
|
||||
provider_message.Message(role='assistant', content='response 1'),
|
||||
provider_message.Message(role='user', content='message 2'),
|
||||
provider_message.Message(role='assistant', content='response 2'),
|
||||
provider_message.Message(role='user', content='current message'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
# All messages should be preserved
|
||||
assert len(result.new_query.messages) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_runner_path_skips_pipeline_truncation(self):
|
||||
"""AgentRunner path should leave query.messages intact at pipeline stage."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_agent_runner_config(max_round=1)
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("current")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='old1'),
|
||||
provider_message.Message(role='assistant', content='old1_resp'),
|
||||
provider_message.Message(role='user', content='current'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert [(msg.role, msg.content) for msg in result.new_query.messages] == [
|
||||
('user', 'old1'),
|
||||
('assistant', 'old1_resp'),
|
||||
('user', 'current'),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_exceeds_limit(self):
|
||||
"""Messages exceeding max-round should be truncated precisely.
|
||||
|
||||
Algorithm: traverse backwards, collect while current_round < max_round, count user messages as rounds.
|
||||
For max_round=2 with 7 messages (u1, a1, u2, a2, u3, a3, u_current):
|
||||
- Iterate: u_current(r=0<2, collect, r=1), a3(r=1<2, collect), u3(r=1<2, collect, r=2)
|
||||
- a2: r=2 not < 2 → break
|
||||
- Collected reverse: [u_current, a3, u3]
|
||||
- Reversed: [u3, a3, u_current] = 3 messages
|
||||
"""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config(max_round=2) # Only keep 2 rounds
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
# Create query with many messages exceeding limit
|
||||
# 7 messages = 3 full rounds + 1 current user
|
||||
query = text_query("current message")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='message 1'),
|
||||
provider_message.Message(role='assistant', content='response 1'),
|
||||
provider_message.Message(role='user', content='message 2'),
|
||||
provider_message.Message(role='assistant', content='response 2'),
|
||||
provider_message.Message(role='user', content='message 3'),
|
||||
provider_message.Message(role='assistant', content='response 3'),
|
||||
provider_message.Message(role='user', content='current message'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
# Should keep exactly 3 messages: message3, response3, current message
|
||||
messages = result.new_query.messages
|
||||
assert len(messages) == 3
|
||||
|
||||
# Verify exact message content
|
||||
assert messages[0].role == 'user'
|
||||
assert messages[0].content == 'message 3'
|
||||
assert messages[1].role == 'assistant'
|
||||
assert messages[1].content == 'response 3'
|
||||
assert messages[2].role == 'user'
|
||||
assert messages[2].content == 'current message'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_empty_messages(self):
|
||||
"""Empty messages list should return empty list."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config()
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = []
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert len(result.new_query.messages) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_single_message(self):
|
||||
"""Single message should be preserved."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config()
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='hello'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert len(result.new_query.messages) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_preserves_order(self):
|
||||
"""Truncation should preserve message order."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config(max_round=2)
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("current")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='user1'),
|
||||
provider_message.Message(role='assistant', content='asst1'),
|
||||
provider_message.Message(role='user', content='user2'),
|
||||
provider_message.Message(role='assistant', content='asst2'),
|
||||
provider_message.Message(role='user', content='user3'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
|
||||
messages = result.new_query.messages
|
||||
assert [(msg.role, msg.content) for msg in messages] == [
|
||||
('user', 'user2'),
|
||||
('assistant', 'asst2'),
|
||||
('user', 'user3'),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_max_round_one(self):
|
||||
"""max-round=1 should only keep last user message."""
|
||||
msgtrun = get_msgtrun_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = msgtrun.ConversationMessageTruncator(app)
|
||||
|
||||
pipeline_config = make_truncate_config(max_round=1)
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("current")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='old1'),
|
||||
provider_message.Message(role='assistant', content='old1_resp'),
|
||||
provider_message.Message(role='user', content='current'),
|
||||
]
|
||||
|
||||
result = await stage.process(query, 'ConversationMessageTruncator')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
messages = result.new_query.messages
|
||||
assert [(msg.role, msg.content) for msg in messages] == [('user', 'current')]
|
||||
|
||||
|
||||
class TestRoundTruncatorDirect:
|
||||
"""Direct tests for RoundTruncator class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_truncator_direct_process(self):
|
||||
"""Test RoundTruncator truncate method directly."""
|
||||
truncator_mod = get_truncator_module()
|
||||
|
||||
app = FakeApp()
|
||||
|
||||
# Get the RoundTruncator class from preregistered
|
||||
for trun_cls in truncator_mod.preregistered_truncators:
|
||||
if trun_cls.name == 'round':
|
||||
trun = trun_cls(app)
|
||||
break
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = make_truncate_config(max_round=3)
|
||||
query.messages = [
|
||||
provider_message.Message(role='user', content='m1'),
|
||||
provider_message.Message(role='assistant', content='r1'),
|
||||
provider_message.Message(role='user', content='m2'),
|
||||
provider_message.Message(role='assistant', content='r2'),
|
||||
provider_message.Message(role='user', content='hello'),
|
||||
]
|
||||
|
||||
result = await trun.truncate(query)
|
||||
|
||||
assert result is not None
|
||||
assert hasattr(result, 'messages')
|
||||
Reference in New Issue
Block a user