test: format test suite

This commit is contained in:
huanghuoguoguo
2026-06-16 11:13:05 +08:00
parent 1ae5aacc00
commit ff0c5a6f0a
92 changed files with 1658 additions and 1713 deletions
+14 -14
View File
@@ -49,7 +49,7 @@ class TestPendingMessage:
"""PendingMessage should be created with correct fields."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -88,7 +88,7 @@ class TestSessionBuffer:
"""SessionBuffer should accept initial messages."""
aggregator = get_aggregator_module()
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -309,7 +309,7 @@ class TestMessageAggregatorAddMessage:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -348,7 +348,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -387,7 +387,7 @@ class TestMessageAggregatorAddMessage:
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -419,7 +419,7 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -445,8 +445,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("hello")
chain2 = text_chain("world")
chain1 = text_chain('hello')
chain2 = text_chain('world')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -476,8 +476,8 @@ class TestMessageAggregatorMerge:
# Should contain both messages with separator
merged_str = str(merged.message_chain)
assert "hello" in merged_str
assert "world" in merged_str
assert 'hello' in merged_str
assert 'world' in merged_str
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
@@ -486,8 +486,8 @@ class TestMessageAggregatorMerge:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain1 = text_chain("first")
chain2 = text_chain("second")
chain1 = text_chain('first')
chain2 = text_chain('second')
event = friend_message_event(chain1)
adapter = mock_adapter()
@@ -545,7 +545,7 @@ class TestMessageAggregatorFlush:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
@@ -597,7 +597,7 @@ class TestMessageAggregatorFlushAll:
app = make_aggregator_app()
agg = aggregator.MessageAggregator(app)
chain = text_chain("hello")
chain = text_chain('hello')
event = friend_message_event(chain)
adapter = mock_adapter()
+34 -4
View File
@@ -15,6 +15,7 @@ from tests.factories import FakeApp
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -36,9 +37,11 @@ def mock_circular_import_chain():
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
@@ -70,9 +73,12 @@ def mock_event_ctx():
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
@@ -87,6 +93,7 @@ def get_chat_handler():
global _chat_handler_module
if _chat_handler_module is None:
from importlib import import_module
_chat_handler_module = import_module('langbot.pkg.pipeline.process.handlers.chat')
return _chat_handler_module
@@ -96,12 +103,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL ChatMessageHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestChatMessageHandlerReal:
"""Tests for real ChatMessageHandler class."""
@@ -188,9 +197,11 @@ class TestChatMessageHandlerReal:
class QuickRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='ok')
@@ -222,9 +233,11 @@ class TestChatMessageHandlerReal:
class SingleRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='response')
@@ -262,9 +275,11 @@ class TestChatHandlerStreaming:
class StreamRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield MessageChunk(role='assistant', content='Hello', is_final=False)
yield MessageChunk(role='assistant', content=' World', is_final=True)
@@ -303,14 +318,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-hint', 'failure-hint': 'Request failed.'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class FailingRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('API error')
yield
@@ -346,14 +366,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'show-error'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class ErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise ValueError('Custom error')
yield
@@ -386,14 +411,19 @@ class TestChatHandlerExceptions:
query.pipeline_config = {
'output': {'misc': {'exception-handling': 'hide'}},
'ai': {'runner': {'runner': 'local-agent'}, 'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}}},
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'prompt': 'default', 'model': {'primary': 'test'}},
},
}
class HideErrorRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
raise RuntimeError('hidden')
yield
@@ -433,4 +463,4 @@ class TestChatHandlerHelper:
chat = get_chat_handler()
handler = chat.ChatMessageHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result
+22 -24
View File
@@ -67,7 +67,11 @@ def make_pipeline_config(**overrides):
for key, value in overrides.items():
if key in base_config and isinstance(base_config[key], dict) and isinstance(value, dict):
for sub_key, sub_value in value.items():
if sub_key in base_config[key] and isinstance(base_config[key][sub_key], dict) and isinstance(sub_value, dict):
if (
sub_key in base_config[key]
and isinstance(base_config[key][sub_key], dict)
and isinstance(sub_value, dict)
):
base_config[key][sub_key].update(sub_value)
else:
base_config[key][sub_key] = sub_value
@@ -141,7 +145,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -163,7 +167,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
# Empty message chain
query = text_query("")
query = text_query('')
query.message_chain = platform_message.MessageChain([])
query.pipeline_config = pipeline_config
@@ -185,7 +189,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query(" ") # Only whitespace
query = text_query(' ') # Only whitespace
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -234,7 +238,7 @@ class TestPreContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello world")
query = text_query('hello world')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -266,7 +270,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -294,7 +298,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("http://example.com")
query = text_query('http://example.com')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -322,7 +326,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("normal message")
query = text_query('normal message')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -343,7 +347,7 @@ class TestContentIgnoreFilter:
await stage.initialize(pipeline_config)
query = text_query("/help me")
query = text_query('/help me')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -368,12 +372,10 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Add a response message
query.resp_messages = [
provider_message.Message(role='assistant', content='Hello back!')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Hello back!')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -398,11 +400,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='Response')
]
query.resp_messages = [provider_message.Message(role='assistant', content='Response')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -422,7 +422,7 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-string content - use model_construct to bypass validation
# The actual content type could be a list of ContentElement objects
@@ -450,11 +450,9 @@ class TestPostContentFilter:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
provider_message.Message(role='assistant', content='')
]
query.resp_messages = [provider_message.Message(role='assistant', content='')]
result = await stage.process(query, 'PostContentFilterStage')
@@ -476,7 +474,7 @@ class TestContentFilterStageInvalidName:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
with pytest.raises(ValueError, match='未知的 stage_inst_name'):
@@ -506,7 +504,7 @@ class TestContentIgnoreFilterDirect:
await stage.initialize(pipeline_config)
query = text_query("normal message without prefix")
query = text_query('normal message without prefix')
query.pipeline_config = pipeline_config
result = await stage.process(query, 'PreContentFilterStage')
@@ -15,6 +15,7 @@ from tests.factories import FakeApp, command_query
# ============== FIXTURE USING IMPORT ISOLATION UTILITY ==============
@pytest.fixture(scope='module')
def mock_circular_import_chain():
"""
@@ -56,6 +57,7 @@ def mock_event_ctx():
@pytest.fixture
def mock_execute_factory():
"""Factory fixture to create mock cmd_mgr.execute generators."""
def _create_execute(
text: str | None = 'ok',
error: str | None = None,
@@ -71,7 +73,9 @@ def mock_execute_factory():
ret.image_base64 = image_base64
ret.file_url = file_url
yield ret
return mock_execute
return _create_execute
@@ -86,6 +90,7 @@ def get_command_handler():
global _command_handler_module
if _command_handler_module is None:
from importlib import import_module
_command_handler_module = import_module('langbot.pkg.pipeline.process.handlers.command')
return _command_handler_module
@@ -95,12 +100,14 @@ def get_entities():
global _entities_module
if _entities_module is None:
from importlib import import_module
_entities_module = import_module('langbot.pkg.pipeline.entities')
return _entities_module
# ============== REAL CommandHandler Tests ==============
@pytest.mark.usefixtures('mock_circular_import_chain')
class TestCommandHandlerReal:
"""Tests for real CommandHandler class."""
@@ -127,6 +134,7 @@ class TestCommandHandlerReal:
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
executed_commands = []
async def track_execute(command_text, full_command_text, query, session):
executed_commands.append(command_text)
ret = Mock()
@@ -334,8 +342,7 @@ class TestCommandHandlerReal:
command = get_command_handler()
fake_app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
fake_app.cmd_mgr.execute = mock_execute_factory(
text='Here is the image:',
image_url='https://example.com/image.png'
text='Here is the image:', image_url='https://example.com/image.png'
)
handler = command.CommandHandler(fake_app)
@@ -393,4 +400,4 @@ class TestCommandHandlerHelper:
command = get_command_handler()
handler = command.CommandHandler(fake_app)
result = handler.cut_str('first line\nsecond line')
assert '...' in result
assert '...' in result
+19 -27
View File
@@ -126,11 +126,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="very long response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='very long response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -151,11 +149,9 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text="short response")])
]
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text='short response')])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -179,14 +175,13 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Non-Plain component (Image)
query.resp_message_chain = [
platform_message.MessageChain([
platform_message.Plain(text="short"),
platform_message.Image(url="https://example.com/img.png")
])
platform_message.MessageChain(
[platform_message.Plain(text='short'), platform_message.Image(url='https://example.com/img.png')]
)
]
result = await stage.process(query, 'LongTextProcessStage')
@@ -213,7 +208,7 @@ class TestLongTextProcessStageProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -232,7 +227,7 @@ class TestLongTextProcessStageProcess:
stage = longtext.LongTextProcessStage(app)
stage.strategy_impl = AsyncMock()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
query.resp_message_chain = []
@@ -242,6 +237,7 @@ class TestLongTextProcessStageProcess:
assert result.new_query is query
stage.strategy_impl.process.assert_not_called()
class TestForwardStrategy:
"""Tests for ForwardComponentStrategy."""
@@ -260,7 +256,7 @@ class TestForwardStrategy:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Create a mock adapter with bot_account_id
mock_adapter = Mock()
@@ -268,10 +264,8 @@ class TestForwardStrategy:
query.adapter = mock_adapter
# Long text exceeding threshold
long_text = "This is a very long response that exceeds the threshold"
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=long_text)])
]
long_text = 'This is a very long response that exceeds the threshold'
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=long_text)])]
result = await stage.process(query, 'LongTextProcessStage')
@@ -297,13 +291,13 @@ class TestForwardStrategy:
await strat.initialize()
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_longtext_config()
mock_adapter = Mock()
mock_adapter.bot_account_id = '12345'
query.adapter = mock_adapter
components = await strat.process("test message", query)
components = await strat.process('test message', query)
assert len(components) == 1
assert isinstance(components[0], platform_message.Forward)
@@ -326,14 +320,12 @@ class TestLongTextThreshold:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
# Text below threshold
short_text = "x" * (threshold - 1)
query.resp_message_chain = [
platform_message.MessageChain([platform_message.Plain(text=short_text)])
]
short_text = 'x' * (threshold - 1)
query.resp_message_chain = [platform_message.MessageChain([platform_message.Plain(text=short_text)])]
result = await stage.process(query, 'LongTextProcessStage')
+7 -7
View File
@@ -115,7 +115,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
# Create query with 3 messages (within limit)
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -154,7 +154,7 @@ class TestRoundTruncatorProcess:
# Create query with many messages exceeding limit
# 7 messages = 3 full rounds + 1 current user
query = text_query("current message")
query = text_query('current message')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='message 1'),
@@ -194,7 +194,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = []
@@ -216,7 +216,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='hello'),
@@ -240,7 +240,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='user1'),
@@ -274,7 +274,7 @@ class TestRoundTruncatorProcess:
await stage.initialize(pipeline_config)
query = text_query("current")
query = text_query('current')
query.pipeline_config = pipeline_config
query.messages = [
provider_message.Message(role='user', content='old1'),
@@ -305,7 +305,7 @@ class TestRoundTruncatorDirect:
trun = trun_cls(app)
break
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = make_truncate_config(max_round=3)
query.messages = [
provider_message.Message(role='user', content='m1'),
+15 -12
View File
@@ -78,7 +78,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello world")
query = text_query('hello world')
result = await stage.process(query, 'PreProcessor')
@@ -113,7 +113,7 @@ class TestPreProcessorNormalText:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("test message")
query = text_query('test message')
result = await stage.process(query, 'PreProcessor')
@@ -194,13 +194,16 @@ class TestPreProcessorImageSegment:
stage = preproc.PreProcessor(app)
# Image query with base64
query = image_query(text="look at this", url=None)
query = image_query(text='look at this', url=None)
# Set base64 on the image component
import langbot_plugin.api.entities.builtin.platform.message as platform_message
chain = platform_message.MessageChain([
platform_message.Plain(text="look at this"),
platform_message.Image(base64="data:image/png;base64,abc123"),
])
chain = platform_message.MessageChain(
[
platform_message.Plain(text='look at this'),
platform_message.Image(base64='data:image/png;base64,abc123'),
]
)
query.message_chain = chain
result = await stage.process(query, 'PreProcessor')
@@ -238,7 +241,7 @@ class TestPreProcessorImageSegment:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = image_query(text="describe this")
query = image_query(text='describe this')
result = await stage.process(query, 'PreProcessor')
@@ -276,7 +279,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
# Set pipeline config with primary model
query.pipeline_config = {
@@ -335,7 +338,7 @@ class TestPreProcessorModelSelection:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = {
'ai': {
@@ -384,7 +387,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = text_query("hello", sender_id=67890)
query = text_query('hello', sender_id=67890)
result = await stage.process(query, 'PreProcessor')
@@ -421,7 +424,7 @@ class TestPreProcessorVariables:
app.plugin_connector.emit_event = AsyncMock(return_value=mock_event_ctx)
stage = preproc.PreProcessor(app)
query = group_text_query("hello", group_id=99999)
query = group_text_query('hello', group_id=99999)
result = await stage.process(query, 'PreProcessor')
+18 -58
View File
@@ -46,7 +46,7 @@ class TestFixedWindowAlgo:
'safety': {
'rate-limit': {
'window-length': 60, # 60 seconds window
'limitation': 10, # 10 requests per window
'limitation': 10, # 10 requests per window
'strategy': 'drop',
}
}
@@ -75,11 +75,9 @@ class TestFixedWindowAlgo:
# Make requests within limit
for i in range(10):
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345'
)
assert result is True, f"Request {i+1} should be allowed"
assert result is True, f'Request {i + 1} should be allowed'
@pytest.mark.asyncio
async def test_fixedwin_exceeds_limit_drop_strategy(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -91,20 +89,12 @@ class TestFixedWindowAlgo:
# Exhaust the limit
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Next request should be denied
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
result = await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
assert result is False, "Request exceeding limit should be denied"
assert result is False, 'Request exceeding limit should be denied'
@pytest.mark.asyncio
async def test_fixedwin_different_sessions_isolated(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -116,20 +106,14 @@ class TestFixedWindowAlgo:
# Exhaust limit for session 1
for i in range(10):
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session1'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session1')
# Session 2 should still have its own limit
result = await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'session2'
sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, 'session2'
)
assert result is True, "Different session should have independent limit"
assert result is True, 'Different session should have independent limit'
@pytest.mark.asyncio
async def test_fixedwin_limit_one_request(self, mock_app_for_algo, sample_query):
@@ -150,19 +134,11 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request allowed
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result1 is True
# Second request denied
result2 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'12345'
)
result2 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, '12345')
assert result2 is False
@pytest.mark.asyncio
@@ -174,11 +150,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# First request creates container
await algo.require_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.require_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Key format: 'LauncherTypes.PERSON_12345' (enum string representation)
expected_key = 'LauncherTypes.PERSON_12345'
@@ -230,7 +202,7 @@ class TestFixedWindowAlgo:
# New request should be allowed (new window)
result = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'test')
assert result is True, "New window should allow new requests"
assert result is True, 'New window should allow new requests'
@pytest.mark.asyncio
async def test_fixedwin_wait_strategy_blocks_until_next_window(self, mock_app_for_algo, sample_query):
@@ -256,29 +228,21 @@ class TestFixedWindowAlgo:
# First request allowed
start_time = time.time()
result1 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result1 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
assert result1 is True
# Exhaust limit
await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
# Third request should wait and then succeed
result3 = await algo.require_access(
sample_query,
provider_session.LauncherTypes.PERSON,
'wait_test'
)
result3 = await algo.require_access(sample_query, provider_session.LauncherTypes.PERSON, 'wait_test')
elapsed = time.time() - start_time
assert result3 is True, "After wait, request should succeed"
assert result3 is True, 'After wait, request should succeed'
# Should have waited approximately until next window
# With 1-second window, elapsed should be > 0.5 second (allowing for timing variance)
# Note: This is a timing-sensitive test, so we use a generous tolerance
assert elapsed >= 0.5, f"Should have waited for next window, elapsed={elapsed:.2f}s"
assert elapsed >= 0.5, f'Should have waited for next window, elapsed={elapsed:.2f}s'
@pytest.mark.asyncio
async def test_fixedwin_release_access(self, mock_app_for_algo, sample_query_with_rate_limit):
@@ -289,11 +253,7 @@ class TestFixedWindowAlgo:
await algo.initialize()
# release_access is empty in current implementation
await algo.release_access(
sample_query_with_rate_limit,
provider_session.LauncherTypes.PERSON,
'12345'
)
await algo.release_access(sample_query_with_rate_limit, provider_session.LauncherTypes.PERSON, '12345')
# Should not raise or change state
assert 'person_12345' not in algo.containers
+25 -27
View File
@@ -55,7 +55,7 @@ def make_session():
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
use_prompt_name="default",
use_prompt_name='default',
using_conversation=None,
conversations=[],
)
@@ -93,11 +93,9 @@ class TestResponseWrapperMessageChain:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_messages = [
platform_message.MessageChain([platform_message.Plain(text="response")])
]
query.resp_messages = [platform_message.MessageChain([platform_message.Plain(text='response')])]
query.resp_message_chain = []
results = []
@@ -125,7 +123,7 @@ class TestResponseWrapperCommand:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -133,7 +131,7 @@ class TestResponseWrapperCommand:
command_resp = Mock()
command_resp.role = 'command'
command_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Help info")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Help info')])
)
query.resp_messages = [command_resp]
@@ -163,7 +161,7 @@ class TestResponseWrapperPlugin:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -171,7 +169,7 @@ class TestResponseWrapperPlugin:
plugin_resp = Mock()
plugin_resp.role = 'plugin'
plugin_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Plugin response")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Plugin response')])
)
query.resp_messages = [plugin_resp]
@@ -211,17 +209,17 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello back!"
assistant_resp.content = 'Hello back!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello back!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello back!')])
)
query.resp_messages = [assistant_resp]
@@ -247,7 +245,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -292,7 +290,7 @@ class TestResponseWrapperAssistant:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
@@ -303,10 +301,10 @@ class TestResponseWrapperAssistant:
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Processing..."
assistant_resp.content = 'Processing...'
assistant_resp.tool_calls = [mock_tool_call]
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Processing...")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Processing...')])
)
query.resp_messages = [assistant_resp]
@@ -346,17 +344,17 @@ class TestResponseWrapperInterrupt:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response with content
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello!"
assistant_resp.content = 'Hello!'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello!")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello!')])
)
query.resp_messages = [assistant_resp]
@@ -384,7 +382,7 @@ class TestResponseWrapperCustomReply:
app.sess_mgr.get_session = AsyncMock(return_value=session)
# Mock plugin connector with custom reply
custom_chain = platform_message.MessageChain([platform_message.Plain(text="Custom reply")])
custom_chain = platform_message.MessageChain([platform_message.Plain(text='Custom reply')])
mock_event_ctx = Mock()
mock_event_ctx.is_prevented_default = Mock(return_value=False)
mock_event_ctx.event = Mock()
@@ -397,17 +395,17 @@ class TestResponseWrapperCustomReply:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Default reply"
assistant_resp.content = 'Default reply'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Default reply")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Default reply')])
)
query.resp_messages = [assistant_resp]
@@ -421,7 +419,7 @@ class TestResponseWrapperCustomReply:
assert len(results[0].new_query.resp_message_chain) == 1
# Should be the custom chain
chain = results[0].new_query.resp_message_chain[0]
assert "Custom reply" in str(chain)
assert 'Custom reply' in str(chain)
class TestResponseWrapperVariables:
@@ -452,7 +450,7 @@ class TestResponseWrapperVariables:
await stage.initialize(pipeline_config)
query = text_query("hello")
query = text_query('hello')
query.pipeline_config = pipeline_config
query.resp_message_chain = []
query.variables['_pipeline_bound_plugins'] = ['plugin1', 'plugin2']
@@ -460,10 +458,10 @@ class TestResponseWrapperVariables:
# Create assistant response
assistant_resp = Mock()
assistant_resp.role = 'assistant'
assistant_resp.content = "Hello"
assistant_resp.content = 'Hello'
assistant_resp.tool_calls = None
assistant_resp.get_content_platform_message_chain = Mock(
return_value=platform_message.MessageChain([platform_message.Plain(text="Hello")])
return_value=platform_message.MessageChain([platform_message.Plain(text='Hello')])
)
query.resp_messages = [assistant_resp]