import numpy as np import logging from chromadb import PersistentClient from pkg.core import app logger = logging.getLogger(__name__) class ChromaIndexManager: def __init__(self, ap: app.Application, collection_name: str = 'default_collection'): self.ap = ap chroma_data_path = './data/chroma' self.client = PersistentClient(path=chroma_data_path) self._collection_name = collection_name self._collection = None self.ap.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}') @property def collection(self): if self._collection is None: self._collection = self.client.get_or_create_collection(name=self._collection_name) self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") return self._collection def add_embeddings_sync( self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str] ): if ( embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents) ): raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.') chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)] metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)] self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents) self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") def search_sync(self, query_embedding: np.ndarray, k: int = 5): """ Searches the Chroma collection for the top-k nearest neighbors. Args: query_embedding: A numpy array of the query embedding. k: The number of results to return. Returns: A dictionary containing query results from Chroma. """ self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") results = self.collection.query( query_embeddings=query_embedding.tolist(), n_results=k, # REMOVE 'ids' from the include list. It's returned by default. include=['metadatas', 'distances', 'documents'], ) self.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.') return results def delete_by_file_id_sync(self, file_id: int): self.logger.info( f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'." ) self.collection.delete(where={'file_id': file_id}) self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.')