mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-28 00:14:21 +00:00
223 lines
7.9 KiB
Python
223 lines
7.9 KiB
Python
"""
|
|
PipelineManager unit tests
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, Mock
|
|
from importlib import import_module
|
|
|
|
|
|
def get_pipelinemgr_module():
|
|
return import_module('langbot.pkg.pipeline.pipelinemgr')
|
|
|
|
|
|
def get_stage_module():
|
|
return import_module('langbot.pkg.pipeline.stage')
|
|
|
|
|
|
def get_entities_module():
|
|
return import_module('langbot.pkg.pipeline.entities')
|
|
|
|
|
|
def get_persistence_pipeline_module():
|
|
return import_module('langbot.pkg.entity.persistence.pipeline')
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_pipeline_manager_initialize(mock_app):
|
|
"""Test pipeline manager initialization"""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
|
|
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
|
|
|
manager = pipelinemgr.PipelineManager(mock_app)
|
|
await manager.initialize()
|
|
|
|
assert manager.stage_dict is not None
|
|
assert len(manager.pipelines) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_load_pipeline(mock_app):
|
|
"""Test loading a single pipeline"""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
persistence_pipeline = get_persistence_pipeline_module()
|
|
|
|
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
|
|
|
manager = pipelinemgr.PipelineManager(mock_app)
|
|
await manager.initialize()
|
|
|
|
# Create test pipeline entity
|
|
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
|
pipeline_entity.uuid = 'test-uuid'
|
|
pipeline_entity.stages = []
|
|
pipeline_entity.config = {'test': 'config'}
|
|
pipeline_entity.extensions_preferences = {'plugins': []}
|
|
|
|
await manager.load_pipeline(pipeline_entity)
|
|
|
|
assert len(manager.pipelines) == 1
|
|
assert manager.pipelines[0].pipeline_entity.uuid == 'test-uuid'
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_pipeline_by_uuid(mock_app):
|
|
"""Test getting pipeline by UUID"""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
persistence_pipeline = get_persistence_pipeline_module()
|
|
|
|
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
|
|
|
manager = pipelinemgr.PipelineManager(mock_app)
|
|
await manager.initialize()
|
|
|
|
# Create and add test pipeline
|
|
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
|
pipeline_entity.uuid = 'test-uuid'
|
|
pipeline_entity.stages = []
|
|
pipeline_entity.config = {}
|
|
pipeline_entity.extensions_preferences = {'plugins': []}
|
|
|
|
await manager.load_pipeline(pipeline_entity)
|
|
|
|
# Test retrieval
|
|
result = await manager.get_pipeline_by_uuid('test-uuid')
|
|
assert result is not None
|
|
assert result.pipeline_entity.uuid == 'test-uuid'
|
|
|
|
# Test non-existent UUID
|
|
result = await manager.get_pipeline_by_uuid('non-existent')
|
|
assert result is None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_remove_pipeline(mock_app):
|
|
"""Test removing a pipeline"""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
persistence_pipeline = get_persistence_pipeline_module()
|
|
|
|
mock_app.persistence_mgr.execute_async = AsyncMock(return_value=Mock(all=Mock(return_value=[])))
|
|
|
|
manager = pipelinemgr.PipelineManager(mock_app)
|
|
await manager.initialize()
|
|
|
|
# Create and add test pipeline
|
|
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
|
pipeline_entity.uuid = 'test-uuid'
|
|
pipeline_entity.stages = []
|
|
pipeline_entity.config = {}
|
|
pipeline_entity.extensions_preferences = {'plugins': []}
|
|
|
|
await manager.load_pipeline(pipeline_entity)
|
|
assert len(manager.pipelines) == 1
|
|
|
|
# Remove pipeline
|
|
await manager.remove_pipeline('test-uuid')
|
|
assert len(manager.pipelines) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_runtime_pipeline_execute(mock_app, sample_query):
|
|
"""Test runtime pipeline execution with real Pydantic models."""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
stage = get_stage_module()
|
|
persistence_pipeline = get_persistence_pipeline_module()
|
|
entities = get_entities_module()
|
|
|
|
# Create result using real Pydantic model (not Mock) to ensure validation
|
|
real_result = entities.StageProcessResult(
|
|
result_type=entities.ResultType.CONTINUE,
|
|
new_query=sample_query,
|
|
user_notice='',
|
|
console_notice='',
|
|
debug_notice='',
|
|
error_notice='',
|
|
)
|
|
|
|
mock_stage = Mock(spec=stage.PipelineStage)
|
|
mock_stage.process = AsyncMock(return_value=real_result)
|
|
|
|
# Create stage container
|
|
stage_container = pipelinemgr.StageInstContainer(inst_name='TestStage', inst=mock_stage)
|
|
|
|
# Create pipeline entity
|
|
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
|
pipeline_entity.config = sample_query.pipeline_config
|
|
pipeline_entity.extensions_preferences = {'plugins': []}
|
|
|
|
# Create runtime pipeline
|
|
runtime_pipeline = pipelinemgr.RuntimePipeline(mock_app, pipeline_entity, [stage_container])
|
|
|
|
# Mock plugin connector
|
|
event_ctx = Mock()
|
|
event_ctx.is_prevented_default = Mock(return_value=False)
|
|
mock_app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
|
|
|
|
# Add query to cached_queries to prevent KeyError in finally block
|
|
mock_app.query_pool.cached_queries[sample_query.query_id] = sample_query
|
|
|
|
# Execute pipeline
|
|
await runtime_pipeline.run(sample_query)
|
|
|
|
# Verify stage was called
|
|
mock_stage.process.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_runtime_pipeline_marks_trace_error_when_stage_returns_error_notice(mock_app, sample_query):
|
|
"""Trace status follows handled stage errors, not only raised exceptions."""
|
|
pipelinemgr = get_pipelinemgr_module()
|
|
stage = get_stage_module()
|
|
persistence_pipeline = get_persistence_pipeline_module()
|
|
entities = get_entities_module()
|
|
|
|
error_result = entities.StageProcessResult(
|
|
result_type=entities.ResultType.INTERRUPT,
|
|
new_query=sample_query,
|
|
user_notice='',
|
|
console_notice='',
|
|
debug_notice='traceback',
|
|
error_notice='model request failed',
|
|
)
|
|
|
|
mock_stage = Mock(spec=stage.PipelineStage)
|
|
mock_stage.process = AsyncMock(return_value=error_result)
|
|
stage_container = pipelinemgr.StageInstContainer(inst_name='FailingStage', inst=mock_stage)
|
|
|
|
pipeline_entity = Mock(spec=persistence_pipeline.LegacyPipeline)
|
|
pipeline_entity.uuid = 'test-pipeline-uuid'
|
|
pipeline_entity.name = 'Test Pipeline'
|
|
pipeline_entity.config = sample_query.pipeline_config
|
|
pipeline_entity.extensions_preferences = {'plugins': []}
|
|
|
|
mock_app.bot_service = AsyncMock()
|
|
mock_app.bot_service.get_bot = AsyncMock(return_value={'name': 'Test Bot'})
|
|
mock_app.monitoring_service = AsyncMock()
|
|
mock_app.monitoring_service.record_message = AsyncMock(return_value='message-1')
|
|
mock_app.monitoring_service.update_session_activity = AsyncMock(return_value=True)
|
|
mock_app.monitoring_service.start_trace = AsyncMock(return_value='trace-1')
|
|
mock_app.monitoring_service.record_span = AsyncMock()
|
|
mock_app.monitoring_service.finish_trace = AsyncMock()
|
|
mock_app.monitoring_service.update_message_status = AsyncMock()
|
|
mock_app.monitoring_service.record_error = AsyncMock()
|
|
|
|
event_ctx = Mock()
|
|
event_ctx.is_prevented_default = Mock(return_value=False)
|
|
mock_app.plugin_connector.emit_event = AsyncMock(return_value=event_ctx)
|
|
mock_app.query_pool.cached_queries[sample_query.query_id] = sample_query
|
|
|
|
runtime_pipeline = pipelinemgr.RuntimePipeline(mock_app, pipeline_entity, [stage_container])
|
|
|
|
await runtime_pipeline.run(sample_query)
|
|
|
|
mock_app.monitoring_service.finish_trace.assert_awaited_once()
|
|
assert mock_app.monitoring_service.finish_trace.await_args.kwargs['status'] == 'error'
|
|
|
|
span_calls = mock_app.monitoring_service.record_span.await_args_list
|
|
stage_span_call = next(call for call in span_calls if call.kwargs['name'] == 'FailingStage')
|
|
root_span_call = next(call for call in span_calls if call.kwargs['kind'] == 'pipeline.query')
|
|
assert stage_span_call.kwargs['status'] == 'error'
|
|
assert stage_span_call.kwargs['error_message'] == 'model request failed'
|
|
assert root_span_call.kwargs['status'] == 'error'
|