mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat(rag): all APIs ok
This commit is contained in:
@@ -84,14 +84,29 @@ class KnowledgeService:
|
||||
|
||||
async def delete_file(self, kb_uuid: str, file_id: str) -> None:
|
||||
"""删除文件"""
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
|
||||
)
|
||||
# TODO: remove from memory
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
await runtime_kb.delete_file(file_id)
|
||||
|
||||
async def delete_knowledge_base(self, kb_uuid: str) -> None:
|
||||
"""删除知识库"""
|
||||
await self.ap.rag_mgr.remove_knowledge_base(kb_uuid)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
|
||||
)
|
||||
# TODO: remove from memory
|
||||
|
||||
# delete files
|
||||
files = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
|
||||
)
|
||||
for file in files:
|
||||
# delete chunks
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file.uuid)
|
||||
)
|
||||
# delete file
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from .services import parser, chunker
|
||||
@@ -130,8 +129,21 @@ class RuntimeKnowledgeBase:
|
||||
)
|
||||
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)
|
||||
|
||||
async def delete_file(self, file_id: str):
|
||||
# delete vector
|
||||
await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id)
|
||||
|
||||
# delete chunk
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id)
|
||||
)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
|
||||
)
|
||||
|
||||
async def dispose(self):
|
||||
pass
|
||||
await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid)
|
||||
|
||||
|
||||
class RAGManager:
|
||||
@@ -192,118 +204,3 @@ class RAGManager:
|
||||
await kb.dispose()
|
||||
self.knowledge_bases.remove(kb)
|
||||
return
|
||||
|
||||
async def delete_data_by_file_id(self, file_id: str):
|
||||
"""
|
||||
Deletes all data associated with a specific file ID, including its chunks and vectors,
|
||||
and the file record itself.
|
||||
"""
|
||||
self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}')
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# delete vectors
|
||||
await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id)
|
||||
self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
|
||||
|
||||
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
|
||||
for chunk in chunks_to_delete:
|
||||
session.delete(chunk)
|
||||
self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}')
|
||||
|
||||
file_to_delete = session.query(File).filter_by(id=file_id).first()
|
||||
if file_to_delete:
|
||||
session.delete(file_to_delete)
|
||||
try:
|
||||
await self.ap.storage_mgr.storage_provider.delete(file_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Error deleting file from storage for file_id {file_id}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
self.ap.logger.info(f'Deleted file record for file_id: {file_id}')
|
||||
else:
|
||||
self.ap.logger.warning(
|
||||
f'File with ID {file_id} not found in database. Skipping deletion of file record.'
|
||||
)
|
||||
session.commit()
|
||||
self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def delete_kb_by_id(self, kb_id: str):
|
||||
"""
|
||||
Deletes a knowledge base and all associated files, chunks, and vectors.
|
||||
This involves querying for associated files and then deleting them.
|
||||
"""
|
||||
self.ap.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}')
|
||||
session = SessionLocal()
|
||||
|
||||
try:
|
||||
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if not kb_to_delete:
|
||||
self.ap.logger.warning(f'Knowledge Base with ID {kb_id} not found.')
|
||||
return
|
||||
|
||||
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
|
||||
|
||||
session.close()
|
||||
|
||||
for file_obj in files_to_delete:
|
||||
try:
|
||||
await self.delete_data_by_file_id(file_obj.id)
|
||||
except Exception as file_del_e:
|
||||
self.ap.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}')
|
||||
|
||||
session = SessionLocal()
|
||||
try:
|
||||
kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if kb_final_delete:
|
||||
session.delete(kb_final_delete)
|
||||
session.commit()
|
||||
self.ap.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}')
|
||||
else:
|
||||
self.ap.logger.warning(
|
||||
f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.'
|
||||
)
|
||||
except Exception as kb_del_e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(
|
||||
f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}',
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
# 如果在最初获取 KB 或文件列表时出错
|
||||
if session.is_active:
|
||||
session.rollback()
|
||||
self.ap.logger.error(
|
||||
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if session.is_active:
|
||||
session.close()
|
||||
|
||||
# async def get_file_content_by_file_id(self, file_id: str) -> str:
|
||||
# file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
|
||||
|
||||
# _, ext = os.path.splitext(file_id.lower())
|
||||
# ext = ext.lstrip('.')
|
||||
|
||||
# try:
|
||||
# text = file_bytes.decode('utf-8')
|
||||
# except UnicodeDecodeError:
|
||||
# return '[非文本文件或编码无法识别]'
|
||||
|
||||
# if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
|
||||
# return text
|
||||
# else:
|
||||
# return f'[未知类型: .{ext}]'
|
||||
|
||||
@@ -40,13 +40,7 @@ class Embedder(BaseService):
|
||||
)
|
||||
|
||||
# save embeddings to vdb
|
||||
await self._run_sync(
|
||||
self.ap.vector_db_mgr.vector_db.add_embeddings,
|
||||
kb_id,
|
||||
chunk_ids,
|
||||
embeddings_list,
|
||||
chunk_dicts,
|
||||
)
|
||||
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
|
||||
|
||||
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ class Retriever(base_service.BaseService):
|
||||
async def retrieve(
|
||||
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
|
||||
) -> list[retriever_entities.RetrieveResultEntry]:
|
||||
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
|
||||
self.ap.logger.info(
|
||||
f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}"
|
||||
)
|
||||
|
||||
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
@@ -22,7 +24,7 @@ class Retriever(base_service.BaseService):
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
|
||||
chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
|
||||
chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k)
|
||||
|
||||
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
|
||||
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
class VectorDatabase(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def add_embeddings(
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
@@ -18,16 +18,20 @@ class VectorDatabase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
|
||||
"""在指定 collection 中检索最相似的向量。"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None:
|
||||
"""根据元数据删除指定 collection 中的向量。"""
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
"""根据 file_id 删除指定 collection 中的向量。"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_or_create_collection(self, collection: str):
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""获取或创建 collection。"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_collection(self, collection: str):
|
||||
pass
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import chromadb
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from chromadb import PersistentClient
|
||||
from pkg.vector.vdb import VectorDatabase
|
||||
from pkg.core import app
|
||||
import chromadb
|
||||
|
||||
|
||||
class ChromaVectorDatabase(VectorDatabase):
|
||||
@@ -12,26 +13,29 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.client = PersistentClient(path=base_path)
|
||||
self._collections = {}
|
||||
|
||||
def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
||||
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
||||
if collection not in self._collections:
|
||||
self._collections[collection] = self.client.get_or_create_collection(name=collection)
|
||||
self._collections[collection] = await asyncio.to_thread(
|
||||
self.client.get_or_create_collection, name=collection
|
||||
)
|
||||
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
|
||||
return self._collections[collection]
|
||||
|
||||
def add_embeddings(
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
col = self.get_or_create_collection(collection)
|
||||
col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||
col = await self.get_or_create_collection(collection)
|
||||
await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||
|
||||
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = self.get_or_create_collection(collection)
|
||||
results = col.query(
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
results = await asyncio.to_thread(
|
||||
col.query,
|
||||
query_embeddings=query_embedding,
|
||||
n_results=k,
|
||||
include=['metadatas', 'distances', 'documents'],
|
||||
@@ -39,7 +43,13 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
|
||||
return results
|
||||
|
||||
def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
|
||||
col = self.get_or_create_collection(collection)
|
||||
col.delete(where=where)
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
await asyncio.to_thread(col.delete, where={'file_id': file_id})
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
await asyncio.to_thread(self.client.delete_collection, name=collection)
|
||||
self.ap.logger.info(f"Chroma collection '{collection}' deleted.")
|
||||
|
||||
Reference in New Issue
Block a user