diff --git a/pkg/api/http/controller/groups/knowledge/__init__.py b/pkg/api/http/controller/groups/knowledge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge/base.py similarity index 90% rename from pkg/api/http/controller/groups/knowledge_base.py rename to pkg/api/http/controller/groups/knowledge/base.py index ce391042..cf5bb44e 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -1,13 +1,9 @@ import quart -from .. import group -import os # 导入 os 用于文件操作 +from ... import group + @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): - # 定义成功方法 - def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response: - return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg}) - async def initialize(self) -> None: @self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases') async def handle_knowledge_bases() -> str: @@ -51,7 +47,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) return self.success(code=0, msg='ok') - @self.route('//files', methods=['GET'], endpoint='get_knowledge_base_files') async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) @@ -68,14 +63,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok', ) - @self.route('//files/', methods=['DELETE'], endpoint='delete_specific_file_in_kb') async def delete_specific_file_in_kb(file_id: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) return self.success(code=0, msg='ok') - + @self.route('//files', methods=['POST'], endpoint='relate_file_with_kb') - async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str: + async def relate_file_id_with_kb(knowledge_base_uuid: str, file_id: str) -> str: if 'file' not in quart.request.files: return self.http_status(400, -1, 'No file part in the request') @@ -83,7 +77,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): file_id = json_data.get('file_id') if not file_id: return self.http_status(400, -1, 'File ID is required') - + # 调用服务层方法将文件与知识库关联 await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) - return self.success(code=0, data={}, msg='ok') \ No newline at end of file + return self.success(code=0, data={}, msg='ok') diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index eb434d88..4eec4e1d 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -14,11 +14,13 @@ from . import group from .groups import provider as groups_provider from .groups import platform as groups_platform from .groups import pipelines as groups_pipelines +from .groups import knowledge as groups_knowledge importutil.import_modules_in_pkg(groups) importutil.import_modules_in_pkg(groups_provider) importutil.import_modules_in_pkg(groups_platform) importutil.import_modules_in_pkg(groups_pipelines) +importutil.import_modules_in_pkg(groups_knowledge) class HTTPController: diff --git a/pkg/core/app.py b/pkg/core/app.py index 2e3c9500..11d25826 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,7 +27,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from pkg.rag.knowledge.RAG_Manager import RAG_Manager +from ..rag.knowledge import mgr as rag_mgr class Application: @@ -48,7 +48,6 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None - # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -101,7 +100,6 @@ class Application: storage_mgr: storagemgr.StorageMgr = None - # ========= HTTP Services ========= user_service: user_service.UserService = None @@ -114,8 +112,7 @@ class Application: bot_service: bot_service.BotService = None - knowledge_base_service: RAG_Manager = None - + knowledge_base_service: rag_mgr.RAGManager = None def __init__(self): pass diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index bb86a6d3..ac76c331 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,7 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr +from ...rag.knowledge import mgr as rag_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -102,7 +102,7 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst - knowledge_base_service_inst = knowledge_base_mgr(ap) + knowledge_base_service_inst = rag_mgr.RAGManager(ap) await knowledge_base_service_inst.initialize_rag_system() ap.knowledge_base_service = knowledge_base_service_inst diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/mgr.py similarity index 67% rename from pkg/rag/knowledge/RAG_Manager.py rename to pkg/rag/knowledge/mgr.py index 9675371b..7d1787e0 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/mgr.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import os import asyncio -import json import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker @@ -14,7 +13,8 @@ from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager from pkg.core import app -class RAG_Manager: + +class RAGManager: ap: app.Application def __init__(self, ap: app.Application, logger: logging.Logger = None): @@ -42,32 +42,54 @@ class RAG_Manager: try: model = EmbeddingModelFactory.create_model( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name + model_type=self.embedding_model_type, model_name_key=self.embedding_model_name + ) + self.logger.info( + f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}" ) - self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}") except Exception as e: - self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}") - raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") + self.logger.critical( + f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}" + ) + raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.') - self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") - self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) - self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) + self.chroma_manager = ChromaIndexManager( + collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}' + ) + self.embedder = Embedder( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager, + ) + self.retriever = Retriever( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager, + ) - async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): + async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5): """ Creates a new knowledge base if it doesn't already exist. """ try: if not self.embedding_model_type or not kb_name: - raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.") + raise ValueError( + 'Embedding model type and knowledge base name must be set before creating a knowledge base.' + ) + def _create_kb_sync(): session = SessionLocal() try: kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: id = uuid.uuid4().int - new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id) + new_kb = KnowledgeBase( + name=kb_name, + description=kb_description, + embedding_model=embedding_model, + top_k=top_k, + id=id, + ) session.add(new_kb) session.commit() session.refresh(new_kb) @@ -80,7 +102,7 @@ class RAG_Manager: self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) raise finally: - session.close() + session.close() return await asyncio.to_thread(_create_kb_sync) except Exception as e: @@ -92,15 +114,17 @@ class RAG_Manager: Retrieves all knowledge bases from the database. """ try: + def _get_all_kbs_sync(): session = SessionLocal() try: return session.query(KnowledgeBase).all() finally: session.close() + return await asyncio.to_thread(_get_all_kbs_sync) except Exception as e: - self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) return [] async def get_knowledge_base_by_id(self, kb_id: int): @@ -108,15 +132,17 @@ class RAG_Manager: Retrieves a specific knowledge base by its ID. """ try: + def _get_kb_sync(kb_id_param): session = SessionLocal() try: return session.query(KnowledgeBase).filter_by(id=kb_id_param).first() finally: session.close() + return await asyncio.to_thread(_get_kb_sync, kb_id) except Exception as e: - self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) return None async def get_files_by_knowledge_base(self, kb_id: int): @@ -124,15 +150,17 @@ class RAG_Manager: Retrieves files associated with a specific knowledge base by querying the File table directly. """ try: + def _get_files_sync(kb_id_param): session = SessionLocal() try: return session.query(File).filter_by(kb_id=kb_id_param).all() finally: session.close() + return await asyncio.to_thread(_get_files_sync, kb_id) except Exception as e: - self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True) return [] async def get_all_files(self): @@ -141,23 +169,27 @@ class RAG_Manager: with any specific knowledge base. """ try: + def _get_all_files_sync(): session = SessionLocal() try: return session.query(File).all() finally: session.close() + return await asyncio.to_thread(_get_all_files_sync) except Exception as e: - self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True) return [] - async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + async def store_data( + self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base' + ): """ Parses, chunks, embeds, and stores data from a given file into the RAG system. Associates the file with a knowledge base using kb_id in the File table. """ - self.logger.info(f"Starting data storage process for file: {file_path}") + self.logger.info(f'Starting data storage process for file: {file_path}') session = SessionLocal() file_obj = None @@ -177,37 +209,43 @@ class RAG_Manager: file_name = os.path.basename(file_path) existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() if existing_file: - self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.") + self.logger.warning( + f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage." + ) return file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type) session.add(file_obj) session.commit() session.refresh(file_obj) - self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}") + self.logger.info( + f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" + ) # 3. 解析文件内容 text = await self.parser.parse(file_path) if not text: - self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.") + self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.') session.delete(file_obj) - session.commit() # 提交删除操作 + session.commit() # 提交删除操作 return # 4. 分块并嵌入/存储块 chunks_texts = await self.chunker.chunk(text) self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) - self.logger.info(f"Data storage process completed for file: {file_path}") + self.logger.info(f'Data storage process completed for file: {file_path}') except Exception as e: session.rollback() - self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) + self.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) if file_obj and file_obj.id: try: await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) except Exception as chroma_e: - self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}") + self.logger.warning( + f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}' + ) raise finally: session.close() @@ -219,7 +257,7 @@ class RAG_Manager: self.logger.info(f"Starting data retrieval process for query: '{query}'") try: retrieved_chunks = await self.retriever.retrieve(query) - self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.") + self.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') return retrieved_chunks except Exception as e: self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) @@ -230,32 +268,32 @@ class RAG_Manager: Deletes all data associated with a specific file ID, including its chunks and vectors, and the file record itself. """ - self.logger.info(f"Starting data deletion process for file_id: {file_id}") + self.logger.info(f'Starting data deletion process for file_id: {file_id}') session = SessionLocal() try: # 1. 从 ChromaDB 删除 embeddings await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) - self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}") + self.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) - self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}") + self.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) - self.logger.info(f"Deleted file record for file_id: {file_id}") + self.logger.info(f'Deleted file record for file_id: {file_id}') else: - self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.") + self.logger.warning(f'File with ID {file_id} not found in database. Skipping deletion of file record.') session.commit() - self.logger.info(f"Successfully completed data deletion for file_id: {file_id}") + self.logger.info(f'Successfully completed data deletion for file_id: {file_id}') except Exception as e: session.rollback() - self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) raise finally: session.close() @@ -265,27 +303,27 @@ class RAG_Manager: Deletes a knowledge base and all associated files, chunks, and vectors. This involves querying for associated files and then deleting them. """ - self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") - session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') + session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 try: kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if not kb_to_delete: - self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") + self.logger.warning(f'Knowledge Base with ID {kb_id} not found.') return # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - + # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 - session.close() + session.close() # 遍历删除每个关联文件及其数据 for file_obj in files_to_delete: try: await self.delete_data_by_file_id(file_obj.id) except Exception as file_del_e: - self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}") + self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') # 记录错误但继续,尝试删除其他文件 # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 @@ -296,12 +334,14 @@ class RAG_Manager: if kb_final_delete: session.delete(kb_final_delete) session.commit() - self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + self.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') else: - self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.") + self.logger.warning( + f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.' + ) except Exception as kb_del_e: session.rollback() - self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True) + self.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True) raise finally: session.close() @@ -310,57 +350,57 @@ class RAG_Manager: # 如果在最初获取 KB 或文件列表时出错 if session.is_active: session.rollback() - self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True) raise finally: if session.is_active: session.close() - - async def get_file_content_by_file_id(self, file_id: str) -> str: - file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) _, ext = os.path.splitext(file_id.lower()) ext = ext.lstrip('.') try: - text = file_bytes.decode("utf-8") + text = file_bytes.decode('utf-8') except UnicodeDecodeError: - return "[非文本文件或编码无法识别]" + return '[非文本文件或编码无法识别]' - if ext in ["txt", "md", "csv", "log", "py", "html"]: + if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: return text else: - return f"[未知类型: .{ext}]" - + return f'[未知类型: .{ext}]' + async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: """ Associates a file with a knowledge base by updating the kb_id in the File table. """ - self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + self.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') session = SessionLocal() try: # 查询知识库是否存在 kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() if not kb: - self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.") + self.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') return # 更新文件的 kb_id file_to_update = session.query(File).filter_by(id=file_id).first() if not file_to_update: - self.logger.error(f"File with ID {file_id} not found.") + self.logger.error(f'File with ID {file_id} not found.') return file_to_update.kb_id = kb.id session.commit() - self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + self.logger.info( + f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' + ) except Exception as e: session.rollback() - self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True) + self.logger.error( + f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}', + exc_info=True, + ) finally: session.close() - -