diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py index e42a14a7..b604e0ed 100644 --- a/pkg/api/http/service/knowledge.py +++ b/pkg/api/http/service/knowledge.py @@ -84,14 +84,29 @@ class KnowledgeService: async def delete_file(self, kb_uuid: str, file_id: str) -> None: """删除文件""" - await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) - ) - # TODO: remove from memory + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + await runtime_kb.delete_file(file_id) async def delete_knowledge_base(self, kb_uuid: str) -> None: """删除知识库""" + await self.ap.rag_mgr.remove_knowledge_base(kb_uuid) + await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) ) - # TODO: remove from memory + + # delete files + files = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid) + ) + for file in files: + # delete chunks + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file.uuid) + ) + # delete file + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid) + ) diff --git a/pkg/rag/knowledge/kbmgr.py b/pkg/rag/knowledge/kbmgr.py index 46be7f75..24f98ea2 100644 --- a/pkg/rag/knowledge/kbmgr.py +++ b/pkg/rag/knowledge/kbmgr.py @@ -1,5 +1,4 @@ from __future__ import annotations -import asyncio import traceback import uuid from .services import parser, chunker @@ -130,8 +129,21 @@ class RuntimeKnowledgeBase: ) return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model) + async def delete_file(self, file_id: str): + # delete vector + await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id) + + # delete chunk + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id) + ) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) + ) + async def dispose(self): - pass + await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid) class RAGManager: @@ -192,118 +204,3 @@ class RAGManager: await kb.dispose() self.knowledge_bases.remove(kb) return - - async def delete_data_by_file_id(self, file_id: str): - """ - Deletes all data associated with a specific file ID, including its chunks and vectors, - and the file record itself. - """ - self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}') - session = SessionLocal() - try: - # delete vectors - await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id) - self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') - - chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() - for chunk in chunks_to_delete: - session.delete(chunk) - self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') - - file_to_delete = session.query(File).filter_by(id=file_id).first() - if file_to_delete: - session.delete(file_to_delete) - try: - await self.ap.storage_mgr.storage_provider.delete(file_id) - except Exception as e: - self.ap.logger.error( - f'Error deleting file from storage for file_id {file_id}: {str(e)}', - exc_info=True, - ) - self.ap.logger.info(f'Deleted file record for file_id: {file_id}') - else: - self.ap.logger.warning( - f'File with ID {file_id} not found in database. Skipping deletion of file record.' - ) - session.commit() - self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}') - except Exception as e: - session.rollback() - self.ap.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) - raise - finally: - session.close() - - async def delete_kb_by_id(self, kb_id: str): - """ - Deletes a knowledge base and all associated files, chunks, and vectors. - This involves querying for associated files and then deleting them. - """ - self.ap.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') - session = SessionLocal() - - try: - kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb_to_delete: - self.ap.logger.warning(f'Knowledge Base with ID {kb_id} not found.') - return - - files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - - session.close() - - for file_obj in files_to_delete: - try: - await self.delete_data_by_file_id(file_obj.id) - except Exception as file_del_e: - self.ap.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - - session = SessionLocal() - try: - kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if kb_final_delete: - session.delete(kb_final_delete) - session.commit() - self.ap.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') - else: - self.ap.logger.warning( - f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.' - ) - except Exception as kb_del_e: - session.rollback() - self.ap.logger.error( - f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', - exc_info=True, - ) - raise - finally: - session.close() - - except Exception as e: - # 如果在最初获取 KB 或文件列表时出错 - if session.is_active: - session.rollback() - self.ap.logger.error( - f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', - exc_info=True, - ) - raise - finally: - if session.is_active: - session.close() - - # async def get_file_content_by_file_id(self, file_id: str) -> str: - # file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) - - # _, ext = os.path.splitext(file_id.lower()) - # ext = ext.lstrip('.') - - # try: - # text = file_bytes.decode('utf-8') - # except UnicodeDecodeError: - # return '[非文本文件或编码无法识别]' - - # if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: - # return text - # else: - # return f'[未知类型: .{ext}]' diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 6b019433..a0ae3d49 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -40,13 +40,7 @@ class Embedder(BaseService): ) # save embeddings to vdb - await self._run_sync( - self.ap.vector_db_mgr.vector_db.add_embeddings, - kb_id, - chunk_ids, - embeddings_list, - chunk_dicts, - ) + await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts) self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.') diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index fc403a57..73c7edaa 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -14,7 +14,9 @@ class Retriever(base_service.BaseService): async def retrieve( self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5 ) -> list[retriever_entities.RetrieveResultEntry]: - self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}") + self.ap.logger.info( + f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}" + ) query_embedding: list[float] = await embedding_model.requester.invoke_embedding( model=embedding_model, @@ -22,7 +24,7 @@ class Retriever(base_service.BaseService): extra_args={}, # TODO: add extra args ) - chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k) + chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' matched_chroma_ids = chroma_results.get('ids', [[]])[0] diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py index 2b7ca400..73a3cc0e 100644 --- a/pkg/vector/vdb.py +++ b/pkg/vector/vdb.py @@ -6,7 +6,7 @@ import numpy as np class VectorDatabase(abc.ABC): @abc.abstractmethod - def add_embeddings( + async def add_embeddings( self, collection: str, ids: list[str], @@ -18,16 +18,20 @@ class VectorDatabase(abc.ABC): pass @abc.abstractmethod - def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: """在指定 collection 中检索最相似的向量。""" pass @abc.abstractmethod - def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None: - """根据元数据删除指定 collection 中的向量。""" + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + """根据 file_id 删除指定 collection 中的向量。""" pass @abc.abstractmethod - def get_or_create_collection(self, collection: str): + async def get_or_create_collection(self, collection: str): """获取或创建 collection。""" pass + + @abc.abstractmethod + async def delete_collection(self, collection: str): + pass diff --git a/pkg/vector/vdbs/chroma.py b/pkg/vector/vdbs/chroma.py index 8f295931..d7e705e5 100644 --- a/pkg/vector/vdbs/chroma.py +++ b/pkg/vector/vdbs/chroma.py @@ -1,9 +1,10 @@ from __future__ import annotations -import chromadb +import asyncio from typing import Any from chromadb import PersistentClient from pkg.vector.vdb import VectorDatabase from pkg.core import app +import chromadb class ChromaVectorDatabase(VectorDatabase): @@ -12,26 +13,29 @@ class ChromaVectorDatabase(VectorDatabase): self.client = PersistentClient(path=base_path) self._collections = {} - def get_or_create_collection(self, collection: str) -> chromadb.Collection: + async def get_or_create_collection(self, collection: str) -> chromadb.Collection: if collection not in self._collections: - self._collections[collection] = self.client.get_or_create_collection(name=collection) + self._collections[collection] = await asyncio.to_thread( + self.client.get_or_create_collection, name=collection + ) self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") return self._collections[collection] - def add_embeddings( + async def add_embeddings( self, collection: str, ids: list[str], embeddings_list: list[list[float]], metadatas: list[dict[str, Any]], ) -> None: - col = self.get_or_create_collection(collection) - col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas) + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas) self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") - def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]: - col = self.get_or_create_collection(collection) - results = col.query( + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]: + col = await self.get_or_create_collection(collection) + results = await asyncio.to_thread( + col.query, query_embeddings=query_embedding, n_results=k, include=['metadatas', 'distances', 'documents'], @@ -39,7 +43,13 @@ class ChromaVectorDatabase(VectorDatabase): self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") return results - def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None: - col = self.get_or_create_collection(collection) - col.delete(where=where) - self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}") + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.delete, where={'file_id': file_id}) + self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}") + + async def delete_collection(self, collection: str): + if collection in self._collections: + del self._collections[collection] + await asyncio.to_thread(self.client.delete_collection, name=collection) + self.ap.logger.info(f"Chroma collection '{collection}' deleted.")