feat(rag): all APIs ok

This commit is contained in:
Junyan Qin
2025-07-16 22:15:03 +08:00
parent 2f2db4d445
commit 333ec346ef
7 changed files with 71 additions and 149 deletions

View File

@@ -84,14 +84,29 @@ class KnowledgeService:
async def delete_file(self, kb_uuid: str, file_id: str) -> None:
    """Delete one file from a knowledge base.

    The runtime knowledge base is looked up before anything is removed, so an
    unknown ``kb_uuid`` leaves every record untouched. The removal itself is
    delegated to the runtime knowledge base's ``delete_file``.

    Args:
        kb_uuid: UUID of the knowledge base that owns the file.
        file_id: UUID of the file to delete.

    Raises:
        Exception: If no runtime knowledge base matches ``kb_uuid``.
    """
    runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
    if runtime_kb is None:
        raise Exception('Knowledge base not found')

    # No File-row delete is issued here: doing it before the check above would
    # drop the row even when the knowledge base is missing, and the runtime
    # knowledge base is expected to remove the row itself.
    await runtime_kb.delete_file(file_id)
async def delete_knowledge_base(self, kb_uuid: str) -> None:
    """Delete a knowledge base together with its file and chunk records.

    The runtime instance is removed from the RAG manager first. The chunk
    rows, the file rows and finally the knowledge base row are then deleted
    with one statement each, children before parents.

    Args:
        kb_uuid: UUID of the knowledge base to delete.
    """
    await self.ap.rag_mgr.remove_knowledge_base(kb_uuid)

    # UUIDs of every file owned by this knowledge base, used as an IN-subquery
    # so chunks can be removed without loading the file rows one by one.
    owned_file_ids = sqlalchemy.select(persistence_rag.File.uuid).where(persistence_rag.File.kb_id == kb_uuid)

    # delete chunks
    await self.ap.persistence_mgr.execute_async(
        sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id.in_(owned_file_ids))
    )

    # delete files
    await self.ap.persistence_mgr.execute_async(
        sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
    )

    # delete the knowledge base row last, once nothing references it any more
    await self.ap.persistence_mgr.execute_async(
        sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
    )

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import asyncio
import traceback
import uuid
from .services import parser, chunker
@@ -130,8 +129,21 @@ class RuntimeKnowledgeBase:
)
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)
async def delete_file(self, file_id: str):
    """Remove everything stored for one file of this knowledge base.

    Three deletions run in order: the file's vectors in the collection named
    after this knowledge base's UUID, the file's chunk rows, and the file row.
    """
    kb_uuid = self.knowledge_base_entity.uuid

    # 1. vectors
    await self.ap.vector_db_mgr.vector_db.delete_by_file_id(kb_uuid, file_id)

    # 2. chunk rows
    drop_chunks = sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id)
    await self.ap.persistence_mgr.execute_async(drop_chunks)

    # 3. the file row itself
    drop_file = sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
    await self.ap.persistence_mgr.execute_async(drop_file)
async def dispose(self):
    """Release this knowledge base's vector storage.

    Deletes the vector database collection named after this knowledge base's
    UUID.
    """
    await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid)
class RAGManager:
@@ -192,118 +204,3 @@ class RAGManager:
await kb.dispose()
self.knowledge_bases.remove(kb)
return
async def delete_data_by_file_id(self, file_id: str):
"""
Deletes all data associated with a specific file ID, including its chunks and vectors,
and the file record itself.

The vectors are removed first; the chunk rows and the file row are then deleted
through a single session and committed together. A failure while deleting the
stored file is logged and ignored. Any other exception rolls the session back,
is logged and is re-raised. The session is closed in every case.
"""
self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}')
# NOTE(review): SessionLocal, Chunk and File are not defined in the visible part
# of this file - confirm where they are imported from.
session = SessionLocal()
try:
# delete vectors
# delete_by_file_id_sync is handed to asyncio.to_thread and awaited.
await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id)
self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
# NOTE(review): the session calls below are plain synchronous calls made from
# inside this coroutine - presumably they block the event loop while running; confirm.
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
for chunk in chunks_to_delete:
session.delete(chunk)
self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}')
file_to_delete = session.query(File).filter_by(id=file_id).first()
if file_to_delete:
session.delete(file_to_delete)
# The stored object is removed on a best-effort basis: a storage error is
# logged and not re-raised, so the database cleanup below still commits.
# NOTE(review): this happens before session.commit(); a later rollback would
# not bring the stored object back - confirm that is acceptable.
try:
await self.ap.storage_mgr.storage_provider.delete(file_id)
except Exception as e:
self.ap.logger.error(
f'Error deleting file from storage for file_id {file_id}: {str(e)}',
exc_info=True,
)
self.ap.logger.info(f'Deleted file record for file_id: {file_id}')
else:
self.ap.logger.warning(
f'File with ID {file_id} not found in database. Skipping deletion of file record.'
)
# One commit covers the chunk deletions and the file deletion.
session.commit()
self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
except Exception as e:
session.rollback()
self.ap.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True)
raise
finally:
session.close()
async def delete_kb_by_id(self, kb_id: str):
"""
Deletes a knowledge base and all associated files, chunks, and vectors.
This involves querying for associated files and then deleting them.

Returns early, after logging a warning, when no knowledge base has the given ID.
A failure while deleting an individual file is logged and does not stop the
remaining deletions. Errors from the knowledge base lookup, the file listing or
the final knowledge base deletion are logged and re-raised.
"""
self.ap.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}')
session = SessionLocal()
try:
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb_to_delete:
self.ap.logger.warning(f'Knowledge Base with ID {kb_id} not found.')
return
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
# The lookup session is closed before the per-file deletions; each call to
# delete_data_by_file_id below is awaited on its own.
session.close()
for file_obj in files_to_delete:
try:
await self.delete_data_by_file_id(file_obj.id)
except Exception as file_del_e:
# Logged only: one failing file does not abort the rest of the loop.
self.ap.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}')
# A fresh session is opened for the knowledge base row itself.
session = SessionLocal()
try:
kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if kb_final_delete:
session.delete(kb_final_delete)
session.commit()
self.ap.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}')
else:
self.ap.logger.warning(
f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.'
)
except Exception as kb_del_e:
session.rollback()
self.ap.logger.error(
f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}',
exc_info=True,
)
raise
finally:
session.close()
except Exception as e:
# Reached when the initial knowledge base lookup or the file listing fails.
# NOTE(review): an exception re-raised by the final-delete step above presumably
# also lands here and is logged a second time - confirm against the real nesting.
# NOTE(review): session.is_active reports transaction state, not whether the
# session is still open - verify these guards against the SQLAlchemy docs.
if session.is_active:
session.rollback()
self.ap.logger.error(
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}',
exc_info=True,
)
raise
finally:
if session.is_active:
session.close()
# async def get_file_content_by_file_id(self, file_id: str) -> str:
# file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
# _, ext = os.path.splitext(file_id.lower())
# ext = ext.lstrip('.')
# try:
# text = file_bytes.decode('utf-8')
# except UnicodeDecodeError:
# return '[非文本文件或编码无法识别]'
# if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
# return text
# else:
# return f'[未知类型: .{ext}]'

View File

@@ -40,13 +40,7 @@ class Embedder(BaseService):
)
# save embeddings to vdb
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chunk_ids,
embeddings_list,
chunk_dicts,
)
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')

View File

@@ -14,7 +14,9 @@ class Retriever(base_service.BaseService):
async def retrieve(
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
) -> list[retriever_entities.RetrieveResultEntry]:
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
self.ap.logger.info(
f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
)
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
@@ -22,7 +24,7 @@ class Retriever(base_service.BaseService):
extra_args={}, # TODO: add extra args
)
chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get('ids', [[]])[0]

View File

@@ -6,7 +6,7 @@ import numpy as np
class VectorDatabase(abc.ABC):
@abc.abstractmethod
def add_embeddings(
async def add_embeddings(
self,
collection: str,
ids: list[str],
@@ -18,16 +18,20 @@ class VectorDatabase(abc.ABC):
pass
@abc.abstractmethod
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
"""在指定 collection 中检索最相似的向量。"""
pass
@abc.abstractmethod
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None:
"""根据元数据删除指定 collection 中的向量。"""
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
"""根据 file_id 删除指定 collection 中的向量。"""
pass
@abc.abstractmethod
def get_or_create_collection(self, collection: str):
async def get_or_create_collection(self, collection: str):
"""获取或创建 collection。"""
pass
@abc.abstractmethod
async def delete_collection(self, collection: str):
    """Delete the collection named ``collection``; implemented by subclasses."""

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
import chromadb
import asyncio
from typing import Any
from chromadb import PersistentClient
from pkg.vector.vdb import VectorDatabase
from pkg.core import app
import chromadb
class ChromaVectorDatabase(VectorDatabase):
@@ -12,26 +13,29 @@ class ChromaVectorDatabase(VectorDatabase):
self.client = PersistentClient(path=base_path)
self._collections = {}
def get_or_create_collection(self, collection: str) -> chromadb.Collection:
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
if collection not in self._collections:
self._collections[collection] = self.client.get_or_create_collection(name=collection)
self._collections[collection] = await asyncio.to_thread(
self.client.get_or_create_collection, name=collection
)
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
return self._collections[collection]
def add_embeddings(
async def add_embeddings(
self,
collection: str,
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
) -> None:
col = self.get_or_create_collection(collection)
col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
col = await self.get_or_create_collection(collection)
await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
col = self.get_or_create_collection(collection)
results = col.query(
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
col = await self.get_or_create_collection(collection)
results = await asyncio.to_thread(
col.query,
query_embeddings=query_embedding,
n_results=k,
include=['metadatas', 'distances', 'documents'],
@@ -39,7 +43,13 @@ class ChromaVectorDatabase(VectorDatabase):
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
return results
def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
    """Delete the embeddings in ``collection`` whose metadata matches ``where``.

    Args:
        collection: Name of the Chroma collection to delete from.
        where: Metadata filter passed unchanged to the collection's ``delete``.
    """
    # NOTE(review): get_or_create_collection is called without ``await`` here;
    # confirm it is still a synchronous method wherever this helper is used.
    target = self.get_or_create_collection(collection)
    target.delete(where=where)

    log = self.ap.logger
    log.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
    """Delete the embeddings in ``collection`` tagged with ``file_id``.

    Args:
        collection: Name of the Chroma collection to delete from.
        file_id: Value matched against the ``file_id`` metadata key.
    """
    target = await self.get_or_create_collection(collection)

    # The collection's delete call is run through asyncio.to_thread and awaited.
    metadata_filter = {'file_id': file_id}
    await asyncio.to_thread(target.delete, where=metadata_filter)

    log = self.ap.logger
    log.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")
async def delete_collection(self, collection: str):
    """Delete ``collection`` from Chroma and forget its cached handle.

    The cached handle, if any, is dropped before the client is asked to delete
    the collection.
    """
    # Equivalent to "delete the key if present": a missing key is not an error.
    self._collections.pop(collection, None)

    # The client's delete call is run through asyncio.to_thread and awaited.
    await asyncio.to_thread(self.client.delete_collection, name=collection)

    log = self.ap.logger
    log.info(f"Chroma collection '{collection}' deleted.")