diff --git a/src/langbot/pkg/rag/knowledge/kbmgr.py b/src/langbot/pkg/rag/knowledge/kbmgr.py index ed71dafa..71d040bc 100644 --- a/src/langbot/pkg/rag/knowledge/kbmgr.py +++ b/src/langbot/pkg/rag/knowledge/kbmgr.py @@ -1,6 +1,7 @@ from __future__ import annotations import mimetypes import os.path +import time import traceback import uuid import zipfile @@ -341,6 +342,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface): filters = settings.pop('filters', {}) trace_context = settings.pop('_trace_context', None) host_span_started_at = self._utc_now() + host_span_started = time.perf_counter() host_span_id = None if trace_context and trace_context.get('trace_id'): host_parent_span_id = trace_context.get('parent_span_id') @@ -380,6 +382,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface): trace_context=trace_context, host_span_id=host_span_id, started_at=host_span_started_at, + duration=int((time.perf_counter() - host_span_started) * 1000), plugin_id=plugin_id, result={ 'results': [], @@ -395,6 +398,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface): trace_context=trace_context, host_span_id=host_span_id, started_at=host_span_started_at, + duration=int((time.perf_counter() - host_span_started) * 1000), plugin_id=plugin_id, result=result, ) @@ -405,6 +409,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface): trace_context: dict[str, Any], host_span_id: str | None, started_at: datetime.datetime, + duration: int, plugin_id: str, result: dict[str, Any], ) -> None: @@ -428,7 +433,7 @@ class RuntimeKnowledgeBase(KnowledgeBaseInterface): kind='rag.retrieval', status=metadata.get('status', 'success'), started_at=started_at, - duration=metadata.get('duration_ms'), + duration=duration, message_id=trace_context.get('message_id'), session_id=trace_context.get('session_id'), bot_id=trace_context.get('bot_id'), diff --git a/tests/unit_tests/rag/test_kbmgr.py b/tests/unit_tests/rag/test_kbmgr.py index a1a16118..21bdb51f 100644 --- a/tests/unit_tests/rag/test_kbmgr.py +++ b/tests/unit_tests/rag/test_kbmgr.py @@ -407,6 +407,32 @@ class TestRuntimeKnowledgeBaseRetrieve: call_args = mock_app.plugin_connector.call_rag_retrieve.call_args assert call_args[0][1]['retrieval_settings']['top_k'] == 5 + @pytest.mark.asyncio + async def test_retrieve_records_host_rag_duration(self, monkeypatch): + """Test host RAG span duration is measured even if plugin omits it.""" + rag_module = get_rag_module() + mock_app = create_mock_app() + mock_app.monitoring_service = AsyncMock() + mock_kb = create_mock_kb_entity() + mock_app.plugin_connector.call_rag_retrieve = AsyncMock( + return_value={'results': [], 'metadata': {'status': 'success'}} + ) + monkeypatch.setattr(rag_module.time, 'perf_counter', Mock(side_effect=[10.0, 10.25])) + + runtime_kb = rag_module.RuntimeKnowledgeBase(mock_app, mock_kb) + + await runtime_kb._retrieve( + 'query text', + { + '_trace_context': { + 'trace_id': 'trace-1', + 'parent_span_id': 'span-root', + } + }, + ) + + assert mock_app.monitoring_service.record_span.await_args.kwargs['duration'] == 250 + @pytest.mark.asyncio async def test_retrieve_converts_dict_to_entry(self): """Test that dict results are converted to RetrievalResultEntry."""