mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-03 20:44:36 +00:00
feat(rag): expose vector listing API with backend filter support
This commit is contained in:
@@ -555,6 +555,20 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||
|
||||
@self.action(PluginToRuntimeAction.VECTOR_LIST)
|
||||
async def vector_list(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
collection_id = data['collection_id']
|
||||
filters = data.get('filters')
|
||||
limit = data.get('limit', 20)
|
||||
offset = data.get('offset', 0)
|
||||
try:
|
||||
items, total = await self.ap.rag_runtime_service.vector_list(
|
||||
collection_id, filters, limit, offset
|
||||
)
|
||||
return handler.ActionResponse.success(data={'items': items, 'total': total})
|
||||
except Exception as e:
|
||||
return _make_rag_error_response(e, 'VectorStoreError', collection_id=collection_id)
|
||||
|
||||
@self.action(PluginToRuntimeAction.GET_KNOWLEDEGE_FILE_STREAM)
|
||||
async def get_knowledge_file_stream(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
storage_path = data['storage_path']
|
||||
|
||||
@@ -75,6 +75,31 @@ class RAGRuntimeService:
|
||||
count = await self.ap.vector_db_mgr.delete_by_filter(collection_name=collection_id, filter=filters)
|
||||
return count
|
||||
|
||||
async def vector_list(
|
||||
self,
|
||||
collection_id: str,
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Handle VECTOR_LIST action.
|
||||
|
||||
Args:
|
||||
collection_id: The collection to list from.
|
||||
filters: Optional metadata filters.
|
||||
limit: Maximum number of items to return.
|
||||
offset: Number of items to skip.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total).
|
||||
"""
|
||||
return await self.ap.vector_db_mgr.list_by_filter(
|
||||
collection_name=collection_id,
|
||||
filter=filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
async def get_file_stream(self, storage_path: str) -> bytes:
|
||||
"""Handle GET_KNOWLEDEGE_FILE_STREAM action.
|
||||
|
||||
|
||||
@@ -49,17 +49,25 @@ def normalize_filter(
|
||||
def strip_unsupported_fields(
|
||||
triples: list[tuple[str, str, Any]],
|
||||
supported_fields: set[str],
|
||||
field_aliases: dict[str, str] | None = None,
|
||||
) -> list[tuple[str, str, Any]]:
|
||||
"""Return only triples whose field is in *supported_fields*.
|
||||
|
||||
If *field_aliases* is provided, aliased field names are mapped to the
|
||||
canonical backend name before the support check. For example,
|
||||
``{'uuid': 'chunk_uuid'}`` allows callers to use ``uuid`` which is
|
||||
transparently rewritten to ``chunk_uuid``.
|
||||
|
||||
Dropped fields are logged at WARNING level so the caller knows they were
|
||||
silently ignored (useful for Milvus / pgvector which only store a fixed
|
||||
schema).
|
||||
"""
|
||||
aliases = field_aliases or {}
|
||||
kept: list[tuple[str, str, Any]] = []
|
||||
for field, op, value in triples:
|
||||
if field in supported_fields:
|
||||
kept.append((field, op, value))
|
||||
resolved = aliases.get(field, field)
|
||||
if resolved in supported_fields:
|
||||
kept.append((resolved, op, value))
|
||||
else:
|
||||
logger.warning(
|
||||
'Filter field %r is not supported by this backend and will be ignored (supported: %s)',
|
||||
|
||||
@@ -157,3 +157,17 @@ class VectorDBManager:
|
||||
Number of deleted vectors (best-effort; some backends return 0).
|
||||
"""
|
||||
return await self.vector_db.delete_by_filter(collection_name, filter)
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection_name: str,
|
||||
filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Proxy: List vectors by metadata filter with pagination.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total).
|
||||
"""
|
||||
return await self.vector_db.list_by_filter(collection_name, filter, limit, offset)
|
||||
|
||||
@@ -92,6 +92,28 @@ class VectorDatabase(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""List vectors matching the given metadata filter with pagination.
|
||||
|
||||
Args:
|
||||
collection: Collection name.
|
||||
filter: Optional metadata filter dict in canonical format.
|
||||
limit: Maximum number of items to return.
|
||||
offset: Number of items to skip.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total) where items is a list of dicts with
|
||||
keys 'id', 'document', 'metadata', and total is the best-effort
|
||||
count of all matching vectors (-1 if unknown).
|
||||
"""
|
||||
return [], -1
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create collection."""
|
||||
|
||||
@@ -221,6 +221,39 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' by filter")
|
||||
return 0 # Chroma delete does not return a count
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
get_kwargs: dict[str, Any] = dict(
|
||||
include=['metadatas', 'documents'],
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if filter:
|
||||
get_kwargs['where'] = filter
|
||||
results = await asyncio.to_thread(col.get, **get_kwargs)
|
||||
|
||||
ids = results.get('ids', [])
|
||||
metadatas = results.get('metadatas', []) or [None] * len(ids)
|
||||
documents = results.get('documents', []) or [None] * len(ids)
|
||||
|
||||
items = []
|
||||
for i, vid in enumerate(ids):
|
||||
items.append({
|
||||
'id': vid,
|
||||
'document': documents[i] if i < len(documents) else None,
|
||||
'metadata': metadatas[i] if i < len(metadatas) else {},
|
||||
})
|
||||
|
||||
# Chroma col.count() gives total in collection; filtered count not available
|
||||
total = await asyncio.to_thread(col.count) if not filter else -1
|
||||
return items, total
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
|
||||
@@ -11,11 +11,14 @@ from langbot.pkg.core import app
|
||||
# silently dropped with a warning.
|
||||
_MILVUS_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||
|
||||
# Callers use canonical metadata key 'uuid' but Milvus stores it as 'chunk_uuid'.
|
||||
_MILVUS_FIELD_ALIASES = {'uuid': 'chunk_uuid'}
|
||||
|
||||
|
||||
def _build_milvus_expr(filter_dict: dict[str, Any]) -> str:
|
||||
"""Translate canonical filter dict into a Milvus boolean expression string."""
|
||||
triples = normalize_filter(filter_dict)
|
||||
triples = strip_unsupported_fields(triples, _MILVUS_SUPPORTED_FIELDS)
|
||||
triples = strip_unsupported_fields(triples, _MILVUS_SUPPORTED_FIELDS, _MILVUS_FIELD_ALIASES)
|
||||
if not triples:
|
||||
return ''
|
||||
|
||||
@@ -340,6 +343,60 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' by filter")
|
||||
return 0 # Milvus delete does not return a count
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
query_kwargs: dict[str, Any] = dict(
|
||||
collection_name=collection,
|
||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if filter:
|
||||
expr = _build_milvus_expr(filter)
|
||||
if expr:
|
||||
query_kwargs['filter'] = expr
|
||||
|
||||
results = await asyncio.to_thread(self.client.query, **query_kwargs)
|
||||
|
||||
items = []
|
||||
for row in results:
|
||||
items.append({
|
||||
'id': row.get('id', ''),
|
||||
'document': row.get('text'),
|
||||
'metadata': {
|
||||
'text': row.get('text', ''),
|
||||
'file_id': row.get('file_id', ''),
|
||||
'uuid': row.get('chunk_uuid', ''),
|
||||
},
|
||||
})
|
||||
|
||||
# Milvus query with count(*)
|
||||
total = -1
|
||||
try:
|
||||
count_kwargs: dict[str, Any] = dict(
|
||||
collection_name=collection,
|
||||
output_fields=['count(*)'],
|
||||
)
|
||||
if filter:
|
||||
expr = _build_milvus_expr(filter)
|
||||
if expr:
|
||||
count_kwargs['filter'] = expr
|
||||
count_result = await asyncio.to_thread(self.client.query, **count_kwargs)
|
||||
if count_result:
|
||||
total = count_result[0].get('count(*)', -1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return items, total
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ Base = declarative_base()
|
||||
# pgvector schema only stores these metadata fields.
|
||||
_PG_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}
|
||||
|
||||
# Callers use canonical metadata key 'uuid' but pgvector stores it as 'chunk_uuid'.
|
||||
_PG_FIELD_ALIASES = {'uuid': 'chunk_uuid'}
|
||||
|
||||
# Map schema field names to SQLAlchemy columns (resolved lazily from PgVectorEntry).
|
||||
_PG_COLUMN_MAP = {
|
||||
'text': 'text',
|
||||
@@ -37,7 +40,7 @@ class PgVectorEntry(Base):
|
||||
def _build_pg_conditions(filter_dict: dict[str, Any]) -> list:
|
||||
"""Translate canonical filter dict into a list of SQLAlchemy conditions."""
|
||||
triples = normalize_filter(filter_dict)
|
||||
triples = strip_unsupported_fields(triples, _PG_SUPPORTED_FIELDS)
|
||||
triples = strip_unsupported_fields(triples, _PG_SUPPORTED_FIELDS, _PG_FIELD_ALIASES)
|
||||
|
||||
conditions = []
|
||||
for field, op, value in triples:
|
||||
@@ -309,6 +312,65 @@ class PgVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.error(f'Error deleting from pgvector by filter: {e}')
|
||||
raise
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import select, func
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
PgVectorEntry.id,
|
||||
PgVectorEntry.text,
|
||||
PgVectorEntry.file_id,
|
||||
PgVectorEntry.chunk_uuid,
|
||||
)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(PgVectorEntry)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
)
|
||||
|
||||
if filter:
|
||||
for cond in _build_pg_conditions(filter):
|
||||
stmt = stmt.filter(cond)
|
||||
count_stmt = count_stmt.filter(cond)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
rows = result.fetchall()
|
||||
|
||||
count_result = await session.execute(count_stmt)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
items = []
|
||||
for row in rows:
|
||||
items.append({
|
||||
'id': row.id,
|
||||
'document': row.text or '',
|
||||
'metadata': {
|
||||
'text': row.text or '',
|
||||
'file_id': row.file_id or '',
|
||||
'uuid': row.chunk_uuid or '',
|
||||
},
|
||||
})
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error listing from pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete all vectors in a collection
|
||||
|
||||
|
||||
@@ -150,6 +150,95 @@ class QdrantVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Deleted embeddings from Qdrant collection '{collection}' by filter")
|
||||
return 0 # Qdrant delete does not return a count
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
exists = await self.client.collection_exists(collection)
|
||||
if not exists:
|
||||
return [], 0
|
||||
|
||||
qdrant_filter = _build_qdrant_filter(filter) if filter else None
|
||||
|
||||
# Qdrant scroll uses cursor-based pagination (offset = point ID),
|
||||
# not numeric skip. To support numeric offset we scroll through
|
||||
# `offset + limit` items and discard the first `offset`.
|
||||
remaining_to_skip = offset
|
||||
remaining_to_collect = limit
|
||||
cursor: int | str | None = None
|
||||
collected: list[dict[str, Any]] = []
|
||||
|
||||
while remaining_to_skip > 0 or remaining_to_collect > 0:
|
||||
batch_size = remaining_to_skip + remaining_to_collect if remaining_to_skip > 0 else remaining_to_collect
|
||||
scroll_kwargs: dict[str, Any] = dict(
|
||||
collection_name=collection,
|
||||
limit=min(batch_size, 256),
|
||||
with_payload=True if remaining_to_skip == 0 else False,
|
||||
with_vectors=False,
|
||||
)
|
||||
if qdrant_filter:
|
||||
scroll_kwargs['scroll_filter'] = qdrant_filter
|
||||
if cursor is not None:
|
||||
scroll_kwargs['offset'] = cursor
|
||||
|
||||
points, next_cursor = await self.client.scroll(**scroll_kwargs)
|
||||
if not points:
|
||||
break
|
||||
|
||||
for point in points:
|
||||
if remaining_to_skip > 0:
|
||||
remaining_to_skip -= 1
|
||||
continue
|
||||
if remaining_to_collect <= 0:
|
||||
break
|
||||
# Re-fetch payload if we skipped it during the skip phase
|
||||
payload = point.payload or {}
|
||||
collected.append({
|
||||
'id': str(point.id),
|
||||
'document': payload.get('text') or payload.get('document'),
|
||||
'metadata': payload,
|
||||
})
|
||||
remaining_to_collect -= 1
|
||||
|
||||
if next_cursor is None:
|
||||
break
|
||||
cursor = next_cursor
|
||||
|
||||
# If we skipped without payload, re-fetch the collected items' payloads
|
||||
# (only needed when offset > 0 and items were collected in a skip batch)
|
||||
if offset > 0 and collected:
|
||||
refetch_ids = [item['id'] for item in collected if not item.get('metadata')]
|
||||
if refetch_ids:
|
||||
fetched_points = await self.client.retrieve(
|
||||
collection_name=collection,
|
||||
ids=refetch_ids,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
payload_map = {str(p.id): p.payload or {} for p in fetched_points}
|
||||
for item in collected:
|
||||
if item['id'] in payload_map:
|
||||
payload = payload_map[item['id']]
|
||||
item['metadata'] = payload
|
||||
item['document'] = payload.get('text') or payload.get('document')
|
||||
|
||||
# Use count() for accurate total (supports filter)
|
||||
total = -1
|
||||
try:
|
||||
count_result = await self.client.count(
|
||||
collection_name=collection,
|
||||
count_filter=qdrant_filter,
|
||||
exact=True,
|
||||
)
|
||||
total = count_result.count
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return collected, total
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
try:
|
||||
await self.client.delete_collection(collection)
|
||||
|
||||
@@ -340,6 +340,48 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Deleted embeddings from SeekDB collection '{collection}' by filter")
|
||||
return 0 # SeekDB delete does not return a count
|
||||
|
||||
async def list_by_filter(
|
||||
self,
|
||||
collection: str,
|
||||
filter: Dict[str, Any] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[Dict[str, Any]], int]:
|
||||
exists = await asyncio.to_thread(self.client.has_collection, collection)
|
||||
if not exists:
|
||||
return [], 0
|
||||
|
||||
if collection not in self._collections:
|
||||
coll = await asyncio.to_thread(self.client.get_collection, collection, embedding_function=None)
|
||||
self._collections[collection] = coll
|
||||
else:
|
||||
coll = self._collections[collection]
|
||||
|
||||
get_kwargs: Dict[str, Any] = dict(
|
||||
include=['metadatas', 'documents'],
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if filter:
|
||||
get_kwargs['where'] = filter
|
||||
|
||||
results = await asyncio.to_thread(coll.get, **get_kwargs)
|
||||
|
||||
ids = results.get('ids', [])
|
||||
metadatas = results.get('metadatas', []) or [None] * len(ids)
|
||||
documents = results.get('documents', []) or [None] * len(ids)
|
||||
|
||||
items = []
|
||||
for i, vid in enumerate(ids):
|
||||
items.append({
|
||||
'id': vid,
|
||||
'document': documents[i] if i < len(documents) else None,
|
||||
'metadata': metadatas[i] if i < len(metadatas) else {},
|
||||
})
|
||||
|
||||
total = await asyncio.to_thread(coll.count) if not filter else -1
|
||||
return items, total
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete the entire collection.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user