# Mirror of https://github.com/langbot-app/LangBot.git
# (synced 2026-06-04 12:56:02 +00:00; 289 lines, 12 KiB, Python)
from __future__ import annotations
|
|
import asyncio
|
|
from typing import Any
|
|
from chromadb import PersistentClient
|
|
from langbot.pkg.vector.vdb import VectorDatabase, SearchType
|
|
from langbot.pkg.core import app
|
|
import chromadb
|
|
import chromadb.errors
|
|
|
|
# RRF smoothing constant (standard value from the literature)
|
|
_RRF_K = 60
|
|
|
|
|
|
class ChromaVectorDatabase(VectorDatabase):
    """``VectorDatabase`` implementation backed by a local persistent Chroma store.

    Three search modes are offered (see ``supported_search_types``):

    * vector    -- nearest-neighbour query on the stored embeddings;
    * full-text -- substring match (``$contains``) on stored documents, unranked;
    * hybrid    -- both of the above, fused with weighted Reciprocal Rank Fusion.

    Every blocking Chroma client call is wrapped in ``asyncio.to_thread`` so
    the event loop is not blocked. Search methods return a dict with (at
    least) the keys ``ids``, ``metadatas``, ``distances`` and ``documents``,
    each holding a single inner list (Chroma's per-query column-major layout).
    """

    def __init__(self, ap: app.Application, base_path: str = './data/chroma'):
        """Create the persistent Chroma client.

        Args:
            ap: Application object; this class uses it for ``ap.logger``.
            base_path: Directory handed to ``chromadb.PersistentClient``.
        """
        self.ap = ap
        self.client = PersistentClient(path=base_path)
        # collection name -> Collection handle, filled lazily by
        # get_or_create_collection so later lookups skip the client.
        self._collections: dict[str, chromadb.Collection] = {}

    @classmethod
    def supported_search_types(cls) -> list[SearchType]:
        """Return the search modes implemented by this backend."""
        return [SearchType.VECTOR, SearchType.FULL_TEXT, SearchType.HYBRID]

    async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
        """Return the collection named *collection*, creating it on first use.

        The handle is cached on the instance; once cached, later calls for the
        same name do not touch the Chroma client.
        """
        if collection not in self._collections:
            self._collections[collection] = await asyncio.to_thread(
                self.client.get_or_create_collection, name=collection
            )
            self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
        return self._collections[collection]

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
        documents: list[str] | None = None,
    ) -> None:
        """Insert or overwrite (upsert) records in *collection*.

        Args:
            collection: Target collection name; created if missing.
            ids: Record ids; records with an existing id are overwritten.
            embeddings_list: One embedding per id.
            metadatas: One metadata dict per id.
            documents: Optional raw text per id. Full-text and hybrid search
                match on this text, so records stored without it are only
                reachable through vector search.

        An empty *ids* list is a no-op. The collection is still ensured, but
        nothing is sent to ``upsert``, because Chroma validates that ids are
        non-empty.
        """
        col = await self.get_or_create_collection(collection)
        if not ids:
            return
        kwargs: dict[str, Any] = dict(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
        if documents is not None:
            kwargs['documents'] = documents
        await asyncio.to_thread(col.upsert, **kwargs)
        self.ap.logger.info(f"Upserted {len(ids)} embeddings to Chroma collection '{collection}'.")

    async def search(
        self,
        collection: str,
        query_embedding: list[float],
        k: int = 5,
        search_type: str = 'vector',
        query_text: str = '',
        filter: dict[str, Any] | None = None,
        vector_weight: float | None = None,
    ) -> dict[str, Any]:
        """Search *collection* and return up to *k* hits.

        Args:
            collection: Collection name; created if missing.
            query_embedding: Query vector, used by the vector and hybrid modes.
            k: Maximum number of results.
            search_type: A ``SearchType`` value. Anything other than full-text
                or hybrid falls back to vector search.
            query_text: Query string for the full-text and hybrid modes.
            filter: Optional Chroma ``where`` metadata filter.
            vector_weight: Hybrid only. The weight in [0, 1] given to the
                vector ranking; the text ranking gets ``1 - vector_weight``.
                ``None`` weighs both rankings equally.

        Returns:
            Column-major result dict (``ids``/``metadatas``/``distances``/
            ``documents``, each wrapping one inner list).
        """
        col = await self.get_or_create_collection(collection)

        if search_type == SearchType.FULL_TEXT:
            return await self._full_text_search(col, collection, k, query_text, filter)
        elif search_type == SearchType.HYBRID:
            return await self._hybrid_search(
                col, collection, query_embedding, k, query_text, filter, vector_weight=vector_weight
            )

        # Default: vector search
        return await self._vector_search(col, collection, query_embedding, k, filter)

    async def _vector_search(
        self,
        col: chromadb.Collection,
        collection: str,
        query_embedding: list[float],
        k: int,
        filter: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Run a nearest-neighbour ``col.query`` and return Chroma's raw result.

        An empty or ``None`` *filter* means "no ``where`` clause".
        """
        query_kwargs: dict[str, Any] = dict(
            query_embeddings=query_embedding,
            n_results=k,
            include=['metadatas', 'distances', 'documents'],
        )
        if filter:
            query_kwargs['where'] = filter
        results = await asyncio.to_thread(col.query, **query_kwargs)
        self.ap.logger.info(
            f"Chroma vector search in '{collection}' returned {len(results.get('ids', [[]])[0])} results."
        )
        return results

    async def _full_text_search(
        self,
        col: chromadb.Collection,
        collection: str,
        k: int,
        query_text: str,
        filter: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Return up to *k* records whose document contains *query_text*.

        The result is wrapped into the column-major query layout, and every
        distance is 0.0 (the match is unranked). An empty *query_text* yields
        an empty result without querying Chroma.
        """
        if not query_text:
            return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

        get_kwargs: dict[str, Any] = dict(
            where_document={'$contains': query_text},
            include=['metadatas', 'documents'],
            limit=k,
        )
        if filter:
            get_kwargs['where'] = filter
        results = await asyncio.to_thread(col.get, **get_kwargs)

        # col.get returns flat lists; wrap into column-major format.
        # Distances are all 0.0 because Chroma's local $contains is a boolean
        # filter with no relevance scoring. Chroma's BM25 sparse embedding
        # function (ChromaBm25EmbeddingFunction) can generate scored sparse
        # vectors, but sparse vector *indexing* is only available on Chroma
        # Cloud, not locally. For ranked results, use hybrid mode or apply a
        # reranker in a downstream stage.
        ids = results.get('ids', [])
        metadatas = results.get('metadatas', []) or [None] * len(ids)
        documents = results.get('documents', []) or [None] * len(ids)
        distances = [0.0] * len(ids)

        self.ap.logger.info(f"Chroma full-text search in '{collection}' returned {len(ids)} results.")
        return {'ids': [ids], 'metadatas': [metadatas], 'distances': [distances], 'documents': [documents]}

    async def _hybrid_search(
        self,
        col: chromadb.Collection,
        collection: str,
        query_embedding: list[float],
        k: int,
        query_text: str,
        filter: dict[str, Any] | None,
        vector_weight: float | None = None,
    ) -> dict[str, Any]:
        """Fuse the vector and full-text rankings with weighted RRF.

        Falls back to plain vector search when *query_text* is empty. The
        returned distances are min-max normalised fused scores in [0, 1],
        where 0.0 is the best-ranked hit. A *vector_weight* outside [0, 1] is
        clamped into that range, and a warning is logged.
        """
        # Fall back to pure vector search when no text is provided
        if not query_text:
            return await self._vector_search(col, collection, query_embedding, k, filter)

        # Run vector search and full-text search concurrently; both offload
        # their Chroma calls to worker threads.
        vector_results, text_results = await asyncio.gather(
            self._vector_search(col, collection, query_embedding, k, filter),
            self._full_text_search(col, collection, k, query_text, filter),
        )

        vector_ids = vector_results.get('ids', [[]])[0]
        text_ids = text_results.get('ids', [[]])[0]

        if not vector_ids and not text_ids:
            return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

        # RRF fusion
        weights = None
        if vector_weight is not None:
            clamped = min(max(vector_weight, 0.0), 1.0)
            if clamped != vector_weight:
                # An out-of-range weight would give the other ranking a
                # negative weight, which actively pushes its hits down.
                self.ap.logger.warning(
                    f"Chroma hybrid search in '{collection}': vector_weight={vector_weight} "
                    f'is outside [0, 1]; clamped to {clamped}.'
                )
                vector_weight = clamped
            weights = [vector_weight, 1.0 - vector_weight]
        self.ap.logger.info(
            f"Chroma hybrid fusion config in '{collection}': "
            f'vector_weight={vector_weight}, weights={weights or [1.0, 1.0]}, '
            f'vector_hits={len(vector_ids)}, text_hits={len(text_ids)}'
        )
        fused = self._rrf_fuse([vector_ids, text_ids], k, weights=weights)
        if not fused:
            return {'ids': [[]], 'metadatas': [[]], 'distances': [[]], 'documents': [[]]}

        # Both sub-searches already asked Chroma for metadatas and documents,
        # so reuse them instead of re-fetching the fused ids with another
        # col.get round-trip.
        payloads = self._payloads_by_id(vector_results, text_results)

        # Normalize RRF scores to 0~1 distances via min-max scaling.
        # Raw RRF scores are tiny (e.g. 0.016~0.033 with k=60) so a naive
        # ``1 - score`` would compress all distances into a narrow 0.96~0.98
        # band with almost no discriminative power. Min-max normalization
        # spreads them across the full 0~1 range (0.0 = best match).
        max_score = fused[0][1]
        min_score = fused[-1][1]
        score_range = max_score - min_score

        ordered_ids: list[str] = []
        ordered_metas: list[Any] = []
        ordered_docs: list[Any] = []
        ordered_dists: list[float] = []
        for doc_id, score in fused:
            if doc_id not in payloads:
                continue
            meta, doc = payloads[doc_id]
            ordered_ids.append(doc_id)
            ordered_metas.append(meta)
            ordered_docs.append(doc)
            ordered_dists.append(1.0 - (score - min_score) / score_range if score_range > 0 else 0.0)

        self.ap.logger.info(
            f"Chroma hybrid search in '{collection}' returned {len(ordered_ids)} results "
            f'(vector={len(vector_ids)}, text={len(text_ids)}).'
        )
        return {
            'ids': [ordered_ids],
            'metadatas': [ordered_metas],
            'distances': [ordered_dists],
            'documents': [ordered_docs],
        }

    @staticmethod
    def _payloads_by_id(*results: dict[str, Any]) -> dict[str, tuple[Any, Any]]:
        """Map each id in the column-major *results* dicts to ``(metadata, document)``.

        When an id appears in several results, the first occurrence wins.
        Missing ``metadatas``/``documents`` columns yield ``None`` entries.
        """
        payloads: dict[str, tuple[Any, Any]] = {}
        for result in results:
            ids = (result.get('ids') or [[]])[0] or []
            metas = (result.get('metadatas') or [[]])[0] or [None] * len(ids)
            docs = (result.get('documents') or [[]])[0] or [None] * len(ids)
            for doc_id, meta, doc in zip(ids, metas, docs):
                payloads.setdefault(doc_id, (meta, doc))
        return payloads

    @staticmethod
    def _rrf_fuse(result_lists: list[list[str]], k: int, weights: list[float] | None = None) -> list[tuple[str, float]]:
        """Reciprocal Rank Fusion over multiple ranked ID lists.

        Each list adds ``weight / (_RRF_K + rank)`` (1-based rank) to the
        score of every id it contains. A list with a non-positive weight is
        skipped entirely, so a zero weight really means "ignore this ranking"
        instead of injecting its ids with a zero score.

        Returns a list of (doc_id, rrf_score) sorted by descending score,
        truncated to *k* entries. Ties keep first-seen order.

        Args:
            result_lists: Ranked ID lists from different search methods.
            k: Number of results to return.
            weights: Per-list weights, indexed like *result_lists*. ``None``
                means equal weight (1.0 each).
        """
        if weights is None:
            weights = [1.0] * len(result_lists)
        scores: dict[str, float] = {}
        for list_idx, ranked_ids in enumerate(result_lists):
            w = weights[list_idx]
            if w <= 0:
                continue
            for rank, doc_id in enumerate(ranked_ids, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + w / (_RRF_K + rank)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k]

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """Delete every record in *collection* whose metadata ``file_id`` equals *file_id*."""
        col = await self.get_or_create_collection(collection)
        await asyncio.to_thread(col.delete, where={'file_id': file_id})
        self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")

    async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
        """Delete records in *collection* matching the Chroma ``where`` *filter*.

        Always returns 0, because Chroma's delete does not report how many
        records it removed.
        """
        col = await self.get_or_create_collection(collection)
        await asyncio.to_thread(col.delete, where=filter)
        self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' by filter")
        return 0  # Chroma delete does not return a count

    async def list_by_filter(
        self,
        collection: str,
        filter: dict[str, Any] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[dict[str, Any]], int]:
        """Page through records in *collection*, optionally filtered by metadata.

        Returns:
            ``(items, total)``. Each item is a dict with keys ``id``,
            ``document`` (may be ``None``) and ``metadata`` (``{}`` when the
            record has none). ``total`` is the whole-collection count when no
            filter is given, and -1 (unknown) otherwise.
        """
        col = await self.get_or_create_collection(collection)
        get_kwargs: dict[str, Any] = dict(
            include=['metadatas', 'documents'],
            limit=limit,
            offset=offset,
        )
        if filter:
            get_kwargs['where'] = filter
        results = await asyncio.to_thread(col.get, **get_kwargs)

        ids = results.get('ids', [])
        metadatas = results.get('metadatas', []) or [None] * len(ids)
        documents = results.get('documents', []) or [None] * len(ids)

        items = [
            {
                'id': vid,
                'document': documents[i] if i < len(documents) else None,
                # Records stored without metadata come back as None; expose {}
                # so callers can always treat 'metadata' as a dict.
                'metadata': (metadatas[i] if i < len(metadatas) else None) or {},
            }
            for i, vid in enumerate(ids)
        ]

        # Collection.count() takes no filter, so it only gives the
        # collection-wide total; a filtered total is reported as -1 (unknown).
        total = await asyncio.to_thread(col.count) if not filter else -1
        return items, total

    async def delete_collection(self, collection: str) -> None:
        """Drop *collection* from Chroma and from the local handle cache.

        A collection that does not exist (``chromadb.errors.NotFoundError``)
        is logged as a warning instead of raising.
        """
        # Evict the cached handle first so a stale Collection is never reused.
        self._collections.pop(collection, None)

        try:
            await asyncio.to_thread(self.client.delete_collection, name=collection)
        except chromadb.errors.NotFoundError:
            self.ap.logger.warning(f"Chroma collection '{collection}' not found.")
            return
        self.ap.logger.info(f"Chroma collection '{collection}' deleted.")
|