Files
LangBot/tests/unit_tests/pipeline/test_pool.py
2026-05-16 10:30:17 +08:00

291 lines
8.6 KiB
Python

"""
Unit tests for QueryPool.
Tests query management, ID generation, and async context handling.
"""
from __future__ import annotations
import pytest
from unittest.mock import Mock, patch
from langbot.pkg.pipeline.pool import QueryPool
# Apply the asyncio marker to every test in this module so pytest-asyncio
# runs the `async def` tests on an event loop.
# NOTE(review): recent pytest-asyncio releases warn when this module-wide mark
# lands on a plain (non-async) `def` test — confirm against the pinned version.
pytestmark = pytest.mark.asyncio
class TestQueryPoolInit:
    """Tests for QueryPool initialization."""

    # These tests are `async def` on purpose: the module-level `pytestmark`
    # applies the asyncio marker to every test, and pytest-asyncio warns when
    # that marker is attached to a plain synchronous test function.

    async def test_init_creates_empty_pool(self):
        """QueryPool starts with an empty query list, an empty cache, and its lock/condition."""
        pool = QueryPool()
        assert pool.queries == []
        assert pool.cached_queries == {}
        assert pool.query_id_counter == 0
        assert pool.pool_lock is not None
        assert pool.condition is not None

    async def test_init_counter_starts_at_zero(self):
        """Counter starts at zero."""
        pool = QueryPool()
        assert pool.query_id_counter == 0
class TestQueryPoolAddQuery:
    """Tests for add_query method."""

    # Where add_query looks up the Query class; patching it lets each test
    # control exactly which object the pool stores.
    _QUERY_CLS = 'langbot.pkg.pipeline.pool.pipeline_query.Query'

    @staticmethod
    async def _add_query(pool, ident=1, bot_uuid=None, **extra):
        """Call ``pool.add_query`` with throwaway arguments.

        ``ident`` is used for both ``launcher_id`` and ``sender_id``;
        ``bot_uuid`` defaults to ``f'bot{ident}'``. Any ``extra`` keywords
        (e.g. ``pipeline_uuid``, ``routed_by_rule``) are forwarded unchanged.
        """
        await pool.add_query(
            bot_uuid=f'bot{ident}' if bot_uuid is None else bot_uuid,
            launcher_type=Mock(),
            launcher_id=ident,
            sender_id=ident,
            message_event=Mock(),
            message_chain=Mock(),
            adapter=Mock(),
            **extra,
        )

    async def test_add_query_adds_query_with_id(self):
        """add_query creates, stores, and caches a Query with the correct ID."""
        pool = QueryPool()
        mock_query = Mock(query_id=0, bot_uuid='test-bot-uuid', launcher_id=12345)
        with patch(self._QUERY_CLS, return_value=mock_query) as MockQuery:
            await self._add_query(pool, ident=12345, bot_uuid='test-bot-uuid')
        # The mock's query_id is preset, so asserting on it proves nothing;
        # check the ID the pool actually handed to the constructor instead.
        # NOTE(review): assumes add_query passes the ID as the `query_id` keyword.
        assert MockQuery.call_args.kwargs['query_id'] == 0
        # Query is added to list and cache
        assert pool.queries[0] is mock_query
        assert pool.cached_queries[0] is mock_query

    async def test_add_query_increments_counter(self):
        """Each add_query increments the counter."""
        pool = QueryPool()
        first, second = Mock(query_id=0), Mock(query_id=1)
        with patch(self._QUERY_CLS, side_effect=[first, second]) as MockQuery:
            await self._add_query(pool, ident=1)
            await self._add_query(pool, ident=2)
        assert pool.query_id_counter == 2
        # Consecutive IDs must be handed out in call order.
        # NOTE(review): assumes add_query passes the ID as the `query_id` keyword.
        assert [c.kwargs['query_id'] for c in MockQuery.call_args_list] == [0, 1]
        assert pool.queries == [first, second]

    async def test_add_query_appends_to_list(self):
        """Query is appended to queries list."""
        pool = QueryPool()
        mock_query = Mock(query_id=0)
        with patch(self._QUERY_CLS, return_value=mock_query):
            await self._add_query(pool)
        assert len(pool.queries) == 1
        assert pool.queries[0] is mock_query

    async def test_add_query_caches_query(self):
        """Query is cached by query_id."""
        pool = QueryPool()
        mock_query = Mock(query_id=0)
        with patch(self._QUERY_CLS, return_value=mock_query):
            await self._add_query(pool)
        assert 0 in pool.cached_queries
        assert pool.cached_queries[0] is mock_query

    async def test_add_query_with_pipeline_uuid(self):
        """Query can have pipeline_uuid set."""
        pool = QueryPool()
        mock_query = Mock(query_id=0, pipeline_uuid='test-pipeline-uuid')
        with patch(self._QUERY_CLS, return_value=mock_query) as MockQuery:
            await self._add_query(pool, pipeline_uuid='test-pipeline-uuid')
        # Verify pipeline_uuid was passed to Query constructor
        assert MockQuery.call_args.kwargs['pipeline_uuid'] == 'test-pipeline-uuid'

    async def test_add_query_sets_routed_by_rule_variable(self):
        """Query has _routed_by_rule variable."""
        pool = QueryPool()
        mock_query = Mock(query_id=0, variables={'_routed_by_rule': True})
        with patch(self._QUERY_CLS, return_value=mock_query) as MockQuery:
            await self._add_query(pool, routed_by_rule=True)
        # Verify variables includes _routed_by_rule
        assert MockQuery.call_args.kwargs['variables']['_routed_by_rule'] is True

    async def test_add_query_notifier_condition(self):
        """add_query notifies waiting consumers."""
        pool = QueryPool()
        mock_query = Mock(query_id=0)
        # wraps= keeps the real notify_all running while recording calls;
        # patch.object restores the original attribute when the block exits.
        with patch(self._QUERY_CLS, return_value=mock_query):
            with patch.object(
                pool.condition, 'notify_all', wraps=pool.condition.notify_all
            ) as notify_spy:
                await self._add_query(pool)
        notify_spy.assert_called_once()
class TestQueryPoolContext:
    """Behaviour of QueryPool used as an ``async with`` context manager."""

    async def test_aenter_acquires_lock(self):
        """Entering the pool holds ``pool_lock`` and yields the pool itself."""
        pool = QueryPool()
        async with pool as entered:
            lock_held = pool.pool_lock.locked()
            assert lock_held
            assert entered is pool

    async def test_aexit_releases_lock(self):
        """Leaving the pool frees ``pool_lock`` again."""
        pool = QueryPool()
        async with pool:
            pass
        still_locked = pool.pool_lock.locked()
        assert not still_locked
class TestQueryPoolEdgeCases:
    """Edge-case scenarios for QueryPool."""

    async def test_multiple_queries_cached_correctly(self):
        """Several queries each get their own, separate cache entry."""
        pool = QueryPool()
        count = 5
        fakes = [Mock(query_id=n) for n in range(count)]
        with patch('langbot.pkg.pipeline.pool.pipeline_query.Query', side_effect=fakes):
            for n in range(count):
                await pool.add_query(
                    bot_uuid=f'bot{n}',
                    launcher_type=Mock(),
                    launcher_id=n,
                    sender_id=n,
                    message_event=Mock(),
                    message_chain=Mock(),
                    adapter=Mock(),
                )
        # One cache slot per query, keyed by its ID.
        assert len(pool.cached_queries) == count
        for n, fake in enumerate(fakes):
            assert pool.cached_queries[n] is fake