diff --git a/pkg/core/app.py b/pkg/core/app.py index ca2c5c1c..092676c6 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -29,6 +29,7 @@ from ..utils import logcache from . import taskmgr from . import entities as core_entities from ..rag.knowledge import mgr as rag_mgr +from ..vector import mgr as vectordb_mgr class Application: @@ -97,6 +98,8 @@ class Application: persistence_mgr: persistencemgr.PersistenceManager = None + vector_db_mgr: vectordb_mgr.VectorDBManager = None + http_ctrl: http_controller.HTTPController = None log_cache: logcache.LogCache = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 18240962..d9521274 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -21,6 +21,7 @@ from ...api.http.service import knowledge as knowledge_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache +from ...vector import mgr as vectordb_mgr from .. import taskmgr @@ -94,6 +95,11 @@ class BuildAppStage(stage.BootingStage): await rag_mgr_inst.initialize_rag_system() ap.rag_mgr = rag_mgr_inst + # 初始化向量数据库管理器 + vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap) + await vectordb_mgr_inst.initialize() + ap.vector_db_mgr = vectordb_mgr_inst + http_ctrl = http_controller.HTTPController(ap) await http_ctrl.initialize() ap.http_ctrl = http_ctrl diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 28b8d666..6e5fe366 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -13,10 +13,9 @@ from pkg.rag.knowledge.services.database import ( from pkg.core import app from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from pkg.core import taskmgr -from ...entity.persistence import rag as persistence_rag import sqlalchemy +from ...entity.persistence import rag as persistence_rag +from pkg.core import taskmgr class RuntimeKnowledgeBase: @@ -24,8 +23,6 @@ class RuntimeKnowledgeBase: knowledge_base_entity: persistence_rag.KnowledgeBase - chroma_manager: ChromaIndexManager - parser: FileParser chunker: Chunker @@ -37,11 +34,12 @@ class RuntimeKnowledgeBase: def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): self.ap = ap self.knowledge_base_entity = knowledge_base_entity - self.chroma_manager = ChromaIndexManager(ap=self.ap) self.parser = FileParser(ap=self.ap) self.chunker = Chunker(ap=self.ap) - self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager) - self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager) + self.embedder = Embedder(ap=self.ap) + self.retriever = Retriever(ap=self.ap) + # 传递kb_id给retriever + self.retriever.kb_id = knowledge_base_entity.uuid async def initialize(self): pass diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py deleted file mode 100644 index 17757b47..00000000 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -import logging -from chromadb import PersistentClient -from pkg.core import app - -logger = logging.getLogger(__name__) - - -class ChromaIndexManager: - def __init__(self, ap: app.Application, collection_name: str = 'default_collection'): - self.ap = ap - chroma_data_path = './data/chroma' - self.client = PersistentClient(path=chroma_data_path) - self._collection_name = collection_name - self._collection = None - - self.ap.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}') - - @property - def collection(self): - if self._collection is None: - self._collection = self.client.get_or_create_collection(name=self._collection_name) - self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") - return self._collection - - def add_embeddings_sync( - self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str] - ): - if ( - embeddings.shape[0] != len(chunk_ids) - or embeddings.shape[0] != len(file_ids) - or embeddings.shape[0] != len(documents) - ): - raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.') - - chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)] - metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)] - - self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") - self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents) - self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") - - def search_sync(self, query_embedding: np.ndarray, k: int = 5): - """ - Searches the Chroma collection for the top-k nearest neighbors. - Args: - query_embedding: A numpy array of the query embedding. - k: The number of results to return. - Returns: - A dictionary containing query results from Chroma. - """ - self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") - results = self.collection.query( - query_embeddings=query_embedding.tolist(), - n_results=k, - # REMOVE 'ids' from the include list. It's returned by default. - include=['metadatas', 'distances', 'documents'], - ) - self.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.') - return results - - def delete_by_file_id_sync(self, file_id: int): - self.logger.info( - f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'." - ) - self.collection.delete(where={'file_id': file_id}) - self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.') diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 34165eab..213896a1 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -1,21 +1,17 @@ from __future__ import annotations import asyncio -import logging import numpy as np from typing import List from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager from ....core import app from ....provider.modelmgr.requester import RuntimeEmbeddingModel class Embedder(BaseService): - def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None: + def __init__(self, ap: app.Application) -> None: super().__init__() - self.logger = logging.getLogger(self.__class__.__name__) - self.chroma_manager = chroma_manager self.ap = ap def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): @@ -24,22 +20,19 @@ class Embedder(BaseService): This function assumes it's called within a context where the session will be committed/rolled back and closed by the caller. """ - self.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).') + self.ap.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).') chunk_objects = [] for text in chunks_texts: chunk = Chunk(file_id=file_id, text=text) session.add(chunk) chunk_objects.append(chunk) session.flush() # This populates the .id attribute for each new chunk object - self.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.') + self.ap.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.') return chunk_objects async def embed_and_store( self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel ) -> List[Chunk]: - if not embedding_model: - raise RuntimeError('Embedding model not loaded. Please check Embedder initialization.') - session = SessionLocal() # Start a session that will live for the whole operation chunk_objects = [] try: @@ -50,7 +43,7 @@ class Embedder(BaseService): session.commit() # Commit chunks to make their IDs permanent and accessible if not chunk_objects: - self.logger.warning( + self.ap.logger.warning( f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.' ) return [] @@ -67,23 +60,28 @@ class Embedder(BaseService): embeddings_np = np.array(embeddings, dtype=np.float32) - self.logger.info('Saving embeddings to Chroma...') chunk_ids = [c.id for c in chunk_objects] - file_ids_for_chroma = [file_id] * len(chunk_ids) - - await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call - self.chroma_manager.add_embeddings_sync, - file_ids_for_chroma, - chunk_ids, + # collection名用kb_id(file对象有kb_id字段) + kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None + if not kb_id: + self.ap.logger.warning('无法获取kb_id,向量存储失败') + return chunk_objects + chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids] + metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids] + await self._run_sync( + self.ap.vector_db_mgr.vector_db.add_embeddings, + kb_id, + chroma_ids, embeddings_np, - chunks, # Pass original chunks texts for documents + metadatas, + chunks, ) - self.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to Chroma.') + self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.') return chunk_objects except Exception as e: session.rollback() # Rollback on any error - self.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True) + self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True) raise # Re-raise the exception to propagate it finally: session.close() # Ensure the session is always closed diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index d330747c..3385021a 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -5,18 +5,18 @@ from typing import List, Dict, Any from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from pkg.vector.vdb import VectorDatabase from ....core import app logger = logging.getLogger(__name__) class Retriever(BaseService): - def __init__(self, ap:app.Application, chroma_manager: ChromaIndexManager): + def __init__(self, ap: app.Application): super().__init__() self.logger = logging.getLogger(self.__class__.__name__) - self.chroma_manager = chroma_manager self.ap = ap + self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: if not self.embedding_model: @@ -27,7 +27,12 @@ class Retriever(BaseService): query_embedding: List[float] = await self.embedding_model.embed_query(query) query_embedding_np = np.array([query_embedding], dtype=np.float32) - chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k) + # collection名用kb_id(假设retriever有kb_id属性或通过ap传递) + kb_id = getattr(self, 'kb_id', None) + if not kb_id: + self.logger.warning('无法获取kb_id,向量检索失败') + return [] + chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' matched_chroma_ids = chroma_results.get('ids', [[]])[0] diff --git a/pkg/vector/mgr.py b/pkg/vector/mgr.py index b2f47d61..ea198ac2 100644 --- a/pkg/vector/mgr.py +++ b/pkg/vector/mgr.py @@ -1,13 +1,18 @@ from __future__ import annotations from ..core import app +from .vdb import VectorDatabase +from .vdbs.chroma import ChromaVectorDatabase class VectorDBManager: ap: app.Application + vector_db: VectorDatabase = None def __init__(self, ap: app.Application): self.ap = ap async def initialize(self): - pass + # 初始化 Chroma 向量数据库(可扩展为多种实现) + if self.vector_db is None: + self.vector_db = ChromaVectorDatabase(self.ap) diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py index 100ded93..20eff831 100644 --- a/pkg/vector/vdb.py +++ b/pkg/vector/vdb.py @@ -1,7 +1,33 @@ from __future__ import annotations - import abc +from typing import Any, List, Dict +import numpy as np class VectorDatabase(abc.ABC): - pass + @abc.abstractmethod + def add_embeddings( + self, + collection: str, + ids: List[str], + embeddings: np.ndarray, + metadatas: List[Dict[str, Any]], + documents: List[str], + ) -> None: + """向指定 collection 添加向量数据。""" + pass + + @abc.abstractmethod + def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + """在指定 collection 中检索最相似的向量。""" + pass + + @abc.abstractmethod + def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None: + """根据元数据删除指定 collection 中的向量。""" + pass + + @abc.abstractmethod + def get_or_create_collection(self, collection: str): + """获取或创建 collection。""" + pass diff --git a/pkg/vector/vdbs/__init__.py b/pkg/vector/vdbs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/vector/vdbs/chroma.py b/pkg/vector/vdbs/chroma.py new file mode 100644 index 00000000..c249c0ba --- /dev/null +++ b/pkg/vector/vdbs/chroma.py @@ -0,0 +1,46 @@ +from __future__ import annotations +import numpy as np +from typing import Any, List, Dict +from chromadb import PersistentClient +from pkg.vector.vdb import VectorDatabase +from pkg.core import app + + +class ChromaVectorDatabase(VectorDatabase): + def __init__(self, ap: app.Application, base_path: str = './data/chroma'): + self.ap = ap + self.client = PersistentClient(path=base_path) + self._collections = {} + + def get_or_create_collection(self, collection: str): + if collection not in self._collections: + self._collections[collection] = self.client.get_or_create_collection(name=collection) + self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") + return self._collections[collection] + + def add_embeddings( + self, + collection: str, + ids: List[str], + embeddings: np.ndarray, + metadatas: List[Dict[str, Any]], + documents: List[str], + ) -> None: + col = self.get_or_create_collection(collection) + col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents) + self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") + + def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + col = self.get_or_create_collection(collection) + results = col.query( + query_embeddings=query_embedding.tolist(), + n_results=k, + include=['metadatas', 'distances', 'documents'], + ) + self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") + return results + + def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None: + col = self.get_or_create_collection(collection) + col.delete(where=where) + self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")