From 7d94a3e8ddfc81fdc207be41e83404dbc2159e60 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <60681390+huanghuoguoguo@users.noreply.github.com> Date: Sat, 13 Jun 2026 16:47:56 +0800 Subject: [PATCH] Add plugin rerank invocation action --- src/langbot/pkg/plugin/handler.py | 29 +++++++++ .../unit_tests/plugin/test_handler_actions.py | 60 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index f5a8511e..c823f95a 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -514,6 +514,35 @@ class RuntimeConnectionHandler(handler.Handler): except Exception as e: return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid) + @self.action(PluginToRuntimeAction.INVOKE_RERANK) + async def invoke_rerank(data: dict[str, Any]) -> handler.ActionResponse: + rerank_model_uuid = data['rerank_model_uuid'] + query = data['query'] + documents = data['documents'] + top_k = data.get('top_k') + extra_args = data.get('extra_args', {}) + + try: + rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid) + except ValueError: + return handler.ActionResponse.error( + message=f'Rerank model with rerank_model_uuid {rerank_model_uuid} not found', + ) + + try: + scores = await rerank_model.provider.invoke_rerank( + model=rerank_model, + query=query, + documents=documents[:64], + extra_args=extra_args, + ) + scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True) + if top_k is not None: + scored = scored[: int(top_k)] + return handler.ActionResponse.success(data={'results': scored}) + except Exception as e: + return _make_rag_error_response(e, 'RerankError', rerank_model_uuid=rerank_model_uuid) + @self.action(PluginToRuntimeAction.VECTOR_UPSERT) async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse: collection_id = data['collection_id'] diff --git a/tests/unit_tests/plugin/test_handler_actions.py b/tests/unit_tests/plugin/test_handler_actions.py index 81bc7570..88ed66b3 100644 --- a/tests/unit_tests/plugin/test_handler_actions.py +++ b/tests/unit_tests/plugin/test_handler_actions.py @@ -27,6 +27,66 @@ def compiled_params(statement): return statement.compile().params +class TestRagRerankAction: + """Tests for RAG rerank action handler.""" + + @pytest.fixture + def app(self): + mock_app = Mock() + mock_app.model_mgr = Mock() + mock_app.logger = Mock() + return mock_app + + @pytest.mark.asyncio + async def test_invokes_rerank_model_and_sorts_scores(self, app): + """Rerank action uses the selected model and returns top scores.""" + provider = Mock() + provider.invoke_rerank = AsyncMock( + return_value=[ + {'index': 0, 'relevance_score': 0.2}, + {'index': 1, 'relevance_score': 0.9}, + ] + ) + rerank_model = SimpleNamespace(provider=provider) + app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=rerank_model) + runtime_handler = make_handler(app) + + response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({ + 'rerank_model_uuid': 'rerank-1', + 'query': 'hello', + 'documents': ['a', 'b'], + 'top_k': 1, + 'extra_args': {'return_documents': False}, + }) + + assert response.code == 0 + assert response.data['results'] == [{'index': 1, 'relevance_score': 0.9}] + app.model_mgr.get_rerank_model_by_uuid.assert_awaited_once_with('rerank-1') + provider.invoke_rerank.assert_awaited_once_with( + model=rerank_model, + query='hello', + documents=['a', 'b'], + extra_args={'return_documents': False}, + ) + + @pytest.mark.asyncio + async def test_returns_error_when_rerank_model_missing(self, app): + """Missing rerank model returns an action error.""" + app.model_mgr.get_rerank_model_by_uuid = AsyncMock( + side_effect=ValueError('not found') + ) + runtime_handler = make_handler(app) + + response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({ + 'rerank_model_uuid': 'missing', + 'query': 'hello', + 'documents': ['a'], + }) + + assert response.code != 0 + assert 'Rerank model with rerank_model_uuid missing not found' in response.message + + class TestInitializePluginSettings: """Tests for initialize_plugin_settings action handler."""