mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 12:56:02 +00:00
68 lines
2.9 KiB
Python
68 lines
2.9 KiB
Python
import numpy as np
|
|
import logging
|
|
from chromadb import PersistentClient
|
|
from pkg.core import app
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ChromaIndexManager:
    """Synchronous wrapper around a single persistent Chroma collection.

    Each record's Chroma id is ``"{file_id}_{chunk_id}"``. Its metadata holds
    ``file_id`` and ``chunk_id``, and the chunk text is stored as the document.
    """

    def __init__(
        self,
        ap: 'app.Application',
        collection_name: str = 'default_collection',
        chroma_data_path: str = './data/chroma',
    ):
        """Create a persistent Chroma client; the collection is resolved lazily.

        Args:
            ap: Application instance; its ``logger`` is used for all logging.
            collection_name: Name of the Chroma collection to use. It is fetched
                or created on the first access of ``collection``.
            chroma_data_path: Path passed to ``PersistentClient`` as the on-disk
                location of the Chroma store. Defaults to ``'./data/chroma'``.
        """
        self.ap = ap
        self.client = PersistentClient(path=chroma_data_path)
        self._collection_name = collection_name
        # Resolved on first use by the ``collection`` property.
        self._collection = None

        self.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}')

    @property
    def logger(self):
        """Return the application's logger.

        This is a property so that every method logs through ``self.ap.logger``
        without a separate attribute that could be left unset.
        """
        return self.ap.logger

    @property
    def collection(self):
        """Return the Chroma collection, getting or creating it on first access."""
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(name=self._collection_name)
            self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
        return self._collection

    def add_embeddings_sync(
        self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]
    ):
        """Add a batch of chunk embeddings, with their documents, to the collection.

        Args:
            file_ids: File id of each chunk.
            chunk_ids: Chunk id of each chunk. It is combined with the file id
                to form the record id ``"{file_id}_{chunk_id}"``.
            embeddings: 2-D array with one row per chunk.
            documents: Text of each chunk, stored alongside its embedding.

        Raises:
            ValueError: If the inputs disagree in length, or if ``embeddings``
                is not a 2-D array.
        """
        count = embeddings.shape[0]
        if count != len(chunk_ids) or count != len(file_ids) or count != len(documents):
            raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.')
        if count == 0:
            # Nothing to store; Chroma's add() rejects an empty id list.
            return
        if embeddings.ndim != 2:
            raise ValueError(f'Expected a 2-D embeddings array, got shape {embeddings.shape}.')

        chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)]
        metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)]

        self.logger.debug(f"Adding {count} embeddings to Chroma collection '{self._collection_name}'.")
        self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents)
        self.logger.info(f"Added {count} embeddings to Chroma collection '{self._collection_name}'.")

    def search_sync(self, query_embedding: np.ndarray, k: int = 5):
        """Search the Chroma collection for the top-k nearest neighbors.

        Args:
            query_embedding: A numpy array of the query embedding.
            k: The number of results to return.

        Returns:
            The query result returned by Chroma's ``collection.query``. It
            includes ids, metadatas, distances and documents.
        """
        self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.")
        results = self.collection.query(
            query_embeddings=query_embedding.tolist(),
            n_results=k,
            # Chroma always returns 'ids'; it is not a valid value for `include`.
            include=['metadatas', 'distances', 'documents'],
        )
        # Guard the count lookup so a missing or empty 'ids' entry cannot break logging.
        ids = results.get('ids') or [[]]
        self.logger.debug(f'Chroma search returned {len(ids[0])} results.')
        return results

    def delete_by_file_id_sync(self, file_id: int):
        """Delete every record whose ``file_id`` metadata equals ``file_id``."""
        self.logger.info(
            f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'."
        )
        self.collection.delete(where={'file_id': file_id})
        self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.')
|