mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 09:46:03 +00:00
Compare commits
1 Commits
fix/litell
...
feat/rag-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d94a3e8dd |
@@ -514,6 +514,35 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
|
return _make_rag_error_response(e, 'EmbeddingError', embedding_model_uuid=embedding_model_uuid)
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.INVOKE_RERANK)
|
||||||
|
async def invoke_rerank(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
rerank_model_uuid = data['rerank_model_uuid']
|
||||||
|
query = data['query']
|
||||||
|
documents = data['documents']
|
||||||
|
top_k = data.get('top_k')
|
||||||
|
extra_args = data.get('extra_args', {})
|
||||||
|
|
||||||
|
try:
|
||||||
|
rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
|
||||||
|
except ValueError:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Rerank model with rerank_model_uuid {rerank_model_uuid} not found',
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
scores = await rerank_model.provider.invoke_rerank(
|
||||||
|
model=rerank_model,
|
||||||
|
query=query,
|
||||||
|
documents=documents[:64],
|
||||||
|
extra_args=extra_args,
|
||||||
|
)
|
||||||
|
scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True)
|
||||||
|
if top_k is not None:
|
||||||
|
scored = scored[: int(top_k)]
|
||||||
|
return handler.ActionResponse.success(data={'results': scored})
|
||||||
|
except Exception as e:
|
||||||
|
return _make_rag_error_response(e, 'RerankError', rerank_model_uuid=rerank_model_uuid)
|
||||||
|
|
||||||
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
|
@self.action(PluginToRuntimeAction.VECTOR_UPSERT)
|
||||||
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
|
async def vector_upsert(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
collection_id = data['collection_id']
|
collection_id = data['collection_id']
|
||||||
|
|||||||
@@ -27,6 +27,66 @@ def compiled_params(statement):
|
|||||||
return statement.compile().params
|
return statement.compile().params
|
||||||
|
|
||||||
|
|
||||||
|
class TestRagRerankAction:
|
||||||
|
"""Tests for RAG rerank action handler."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self):
|
||||||
|
mock_app = Mock()
|
||||||
|
mock_app.model_mgr = Mock()
|
||||||
|
mock_app.logger = Mock()
|
||||||
|
return mock_app
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invokes_rerank_model_and_sorts_scores(self, app):
|
||||||
|
"""Rerank action uses the selected model and returns top scores."""
|
||||||
|
provider = Mock()
|
||||||
|
provider.invoke_rerank = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{'index': 0, 'relevance_score': 0.2},
|
||||||
|
{'index': 1, 'relevance_score': 0.9},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rerank_model = SimpleNamespace(provider=provider)
|
||||||
|
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=rerank_model)
|
||||||
|
runtime_handler = make_handler(app)
|
||||||
|
|
||||||
|
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
|
||||||
|
'rerank_model_uuid': 'rerank-1',
|
||||||
|
'query': 'hello',
|
||||||
|
'documents': ['a', 'b'],
|
||||||
|
'top_k': 1,
|
||||||
|
'extra_args': {'return_documents': False},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.code == 0
|
||||||
|
assert response.data['results'] == [{'index': 1, 'relevance_score': 0.9}]
|
||||||
|
app.model_mgr.get_rerank_model_by_uuid.assert_awaited_once_with('rerank-1')
|
||||||
|
provider.invoke_rerank.assert_awaited_once_with(
|
||||||
|
model=rerank_model,
|
||||||
|
query='hello',
|
||||||
|
documents=['a', 'b'],
|
||||||
|
extra_args={'return_documents': False},
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_error_when_rerank_model_missing(self, app):
|
||||||
|
"""Missing rerank model returns an action error."""
|
||||||
|
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||||
|
side_effect=ValueError('not found')
|
||||||
|
)
|
||||||
|
runtime_handler = make_handler(app)
|
||||||
|
|
||||||
|
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
|
||||||
|
'rerank_model_uuid': 'missing',
|
||||||
|
'query': 'hello',
|
||||||
|
'documents': ['a'],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.code != 0
|
||||||
|
assert 'Rerank model with rerank_model_uuid missing not found' in response.message
|
||||||
|
|
||||||
|
|
||||||
class TestInitializePluginSettings:
|
class TestInitializePluginSettings:
|
||||||
"""Tests for initialize_plugin_settings action handler."""
|
"""Tests for initialize_plugin_settings action handler."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user