perf: definitions

This commit is contained in:
Junyan Qin
2025-07-10 16:45:59 +08:00
parent ac03a2dceb
commit 75c3ddde19
6 changed files with 112 additions and 79 deletions

View File

@@ -1,13 +1,9 @@
import quart
from .. import group
import os # 导入 os 用于文件操作
from ... import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
# 定义成功方法
def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response:
return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg})
async def initialize(self) -> None:
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
async def handle_knowledge_bases() -> str:
@@ -51,7 +47,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid))
return self.success(code=0, msg='ok')
@self.route('/<knowledge_base_uuid>/files', methods=['GET'], endpoint='get_knowledge_base_files')
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid))
@@ -68,14 +63,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
msg='ok',
)
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'], endpoint='delete_specific_file_in_kb')
async def delete_specific_file_in_kb(knowledge_base_uuid: str, file_id: str) -> str:
    """Delete a single file (and the data stored for it) via the knowledge base service.

    The route declares two URL variables, and the framework hands every URL
    variable to the view function as a keyword argument, so the handler must
    accept both of them even though only ``file_id`` is used.

    Args:
        knowledge_base_uuid: Knowledge base segment of the URL. Accepted so the
            call signature matches the route; not used for the deletion.
        file_id: File identifier from the URL; converted with ``int()`` before
            being passed to ``delete_data_by_file_id``.

    Returns:
        The value of ``self.success(code=0, msg='ok')``.

    Raises:
        ValueError: If ``file_id`` is not a valid integer literal.
    """
    # NOTE(review): the file is deleted by id alone; nothing here checks that it
    # actually belongs to `knowledge_base_uuid` - confirm whether that is intended.
    await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id))
    return self.success(code=0, msg='ok')
@self.route('/<knowledge_base_uuid>/files', methods=['POST'], endpoint='relate_file_with_kb')
async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str:
async def relate_file_id_with_kb(knowledge_base_uuid: str, file_id: str) -> str:
if 'file' not in quart.request.files:
return self.http_status(400, -1, 'No file part in the request')
@@ -83,7 +77,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
file_id = json_data.get('file_id')
if not file_id:
return self.http_status(400, -1, 'File ID is required')
# 调用服务层方法将文件与知识库关联
await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id))
return self.success(code=0, data={}, msg='ok')
return self.success(code=0, data={}, msg='ok')

View File

@@ -14,11 +14,13 @@ from . import group
from .groups import provider as groups_provider
from .groups import platform as groups_platform
from .groups import pipelines as groups_pipelines
from .groups import knowledge as groups_knowledge
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
importutil.import_modules_in_pkg(groups_pipelines)
importutil.import_modules_in_pkg(groups_knowledge)
class HTTPController:

View File

@@ -27,7 +27,7 @@ from ..storage import mgr as storagemgr
from ..utils import logcache
from . import taskmgr
from . import entities as core_entities
from pkg.rag.knowledge.RAG_Manager import RAG_Manager
from ..rag.knowledge import mgr as rag_mgr
class Application:
@@ -48,7 +48,6 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
@@ -101,7 +100,6 @@ class Application:
storage_mgr: storagemgr.StorageMgr = None
# ========= HTTP Services =========
user_service: user_service.UserService = None
@@ -114,8 +112,7 @@ class Application:
bot_service: bot_service.BotService = None
knowledge_base_service: RAG_Manager = None
knowledge_base_service: rag_mgr.RAGManager = None
def __init__(self):
pass

View File

@@ -9,7 +9,7 @@ from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr
from ...rag.knowledge import mgr as rag_mgr
from ...platform import botmgr as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -102,7 +102,7 @@ class BuildAppStage(stage.BootingStage):
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
ap.embedding_models_service = embedding_models_service_inst
knowledge_base_service_inst = knowledge_base_mgr(ap)
knowledge_base_service_inst = rag_mgr.RAGManager(ap)
await knowledge_base_service_inst.initialize_rag_system()
ap.knowledge_base_service = knowledge_base_service_inst

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import os
import asyncio
import json
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
@@ -14,7 +13,8 @@ from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.core import app
class RAG_Manager:
class RAGManager:
ap: app.Application
def __init__(self, ap: app.Application, logger: logging.Logger = None):
@@ -42,32 +42,54 @@ class RAG_Manager:
try:
model = EmbeddingModelFactory.create_model(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name
model_type=self.embedding_model_type, model_name_key=self.embedding_model_name
)
self.logger.info(
f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}"
)
self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}")
except Exception as e:
self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}")
raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
self.logger.critical(
f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}"
)
raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.')
self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
self.chroma_manager = ChromaIndexManager(
collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}'
)
self.embedder = Embedder(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
self.retriever = Retriever(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5):
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5):
"""
Creates a new knowledge base if it doesn't already exist.
"""
try:
if not self.embedding_model_type or not kb_name:
raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.")
raise ValueError(
'Embedding model type and knowledge base name must be set before creating a knowledge base.'
)
def _create_kb_sync():
session = SessionLocal()
try:
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
if not kb:
id = uuid.uuid4().int
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id)
new_kb = KnowledgeBase(
name=kb_name,
description=kb_description,
embedding_model=embedding_model,
top_k=top_k,
id=id,
)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
@@ -80,7 +102,7 @@ class RAG_Manager:
self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
raise
finally:
session.close()
session.close()
return await asyncio.to_thread(_create_kb_sync)
except Exception as e:
@@ -92,15 +114,17 @@ class RAG_Manager:
Retrieves all knowledge bases from the database.
"""
try:
def _get_all_kbs_sync():
session = SessionLocal()
try:
return session.query(KnowledgeBase).all()
finally:
session.close()
return await asyncio.to_thread(_get_all_kbs_sync)
except Exception as e:
self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True)
return []
async def get_knowledge_base_by_id(self, kb_id: int):
@@ -108,15 +132,17 @@ class RAG_Manager:
Retrieves a specific knowledge base by its ID.
"""
try:
def _get_kb_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
finally:
session.close()
return await asyncio.to_thread(_get_kb_sync, kb_id)
except Exception as e:
self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True)
return None
async def get_files_by_knowledge_base(self, kb_id: int):
@@ -124,15 +150,17 @@ class RAG_Manager:
Retrieves files associated with a specific knowledge base by querying the File table directly.
"""
try:
def _get_files_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(File).filter_by(kb_id=kb_id_param).all()
finally:
session.close()
return await asyncio.to_thread(_get_files_sync, kb_id)
except Exception as e:
self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
self.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True)
return []
async def get_all_files(self):
@@ -141,23 +169,27 @@ class RAG_Manager:
with any specific knowledge base.
"""
try:
def _get_all_files_sync():
session = SessionLocal()
try:
return session.query(File).all()
finally:
session.close()
return await asyncio.to_thread(_get_all_files_sync)
except Exception as e:
self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True)
self.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True)
return []
async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
async def store_data(
self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base'
):
"""
Parses, chunks, embeds, and stores data from a given file into the RAG system.
Associates the file with a knowledge base using kb_id in the File table.
"""
self.logger.info(f"Starting data storage process for file: {file_path}")
self.logger.info(f'Starting data storage process for file: {file_path}')
session = SessionLocal()
file_obj = None
@@ -177,37 +209,43 @@ class RAG_Manager:
file_name = os.path.basename(file_path)
existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
if existing_file:
self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.")
self.logger.warning(
f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage."
)
return
file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
session.add(file_obj)
session.commit()
session.refresh(file_obj)
self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}")
self.logger.info(
f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}"
)
# 3. 解析文件内容
text = await self.parser.parse(file_path)
if not text:
self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.")
self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.')
session.delete(file_obj)
session.commit() # 提交删除操作
session.commit() # 提交删除操作
return
# 4. 分块并嵌入/存储块
chunks_texts = await self.chunker.chunk(text)
self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
self.logger.info(f"Data storage process completed for file: {file_path}")
self.logger.info(f'Data storage process completed for file: {file_path}')
except Exception as e:
session.rollback()
self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True)
self.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
if file_obj and file_obj.id:
try:
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
except Exception as chroma_e:
self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}")
self.logger.warning(
f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}'
)
raise
finally:
session.close()
@@ -219,7 +257,7 @@ class RAG_Manager:
self.logger.info(f"Starting data retrieval process for query: '{query}'")
try:
retrieved_chunks = await self.retriever.retrieve(query)
self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.")
self.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.')
return retrieved_chunks
except Exception as e:
self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
@@ -230,32 +268,32 @@ class RAG_Manager:
Deletes all data associated with a specific file ID, including its chunks and vectors,
and the file record itself.
"""
self.logger.info(f"Starting data deletion process for file_id: {file_id}")
self.logger.info(f'Starting data deletion process for file_id: {file_id}')
session = SessionLocal()
try:
# 1. 从 ChromaDB 删除 embeddings
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}")
self.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
# 2. 删除与文件关联的 chunks 记录
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
for chunk in chunks_to_delete:
session.delete(chunk)
self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}")
self.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}')
# 3. 删除文件记录本身
file_to_delete = session.query(File).filter_by(id=file_id).first()
if file_to_delete:
session.delete(file_to_delete)
self.logger.info(f"Deleted file record for file_id: {file_id}")
self.logger.info(f'Deleted file record for file_id: {file_id}')
else:
self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.")
self.logger.warning(f'File with ID {file_id} not found in database. Skipping deletion of file record.')
session.commit()
self.logger.info(f"Successfully completed data deletion for file_id: {file_id}")
self.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
except Exception as e:
session.rollback()
self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
self.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True)
raise
finally:
session.close()
@@ -265,27 +303,27 @@ class RAG_Manager:
Deletes a knowledge base and all associated files, chunks, and vectors.
This involves querying for associated files and then deleting them.
"""
self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
session = SessionLocal() # 使用新的会话来获取 KB 和关联文件
self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}')
session = SessionLocal() # 使用新的会话来获取 KB 和关联文件
try:
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb_to_delete:
self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
self.logger.warning(f'Knowledge Base with ID {kb_id} not found.')
return
# 获取所有关联的文件,通过 File 表的 kb_id 字段查询
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
# 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话
session.close()
session.close()
# 遍历删除每个关联文件及其数据
for file_obj in files_to_delete:
try:
await self.delete_data_by_file_id(file_obj.id)
except Exception as file_del_e:
self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}")
self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}')
# 记录错误但继续,尝试删除其他文件
# 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身
@@ -296,12 +334,14 @@ class RAG_Manager:
if kb_final_delete:
session.delete(kb_final_delete)
session.commit()
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
self.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}')
else:
self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.")
self.logger.warning(
f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.'
)
except Exception as kb_del_e:
session.rollback()
self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True)
self.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True)
raise
finally:
session.close()
@@ -310,57 +350,57 @@ class RAG_Manager:
# 如果在最初获取 KB 或文件列表时出错
if session.is_active:
session.rollback()
self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True)
self.logger.error(f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True)
raise
finally:
if session.is_active:
session.close()
async def get_file_content_by_file_id(self, file_id: str) -> str:
# Load a stored file through the application's storage provider and return
# its content as text, or a bracketed placeholder string when it cannot be
# shown as text.
#
# file_id: key passed to storage_provider.load(); its extension (taken from
#          the lower-cased id) decides whether the text is returned.
# Returns: the UTF-8 decoded text for the listed extensions, otherwise a
#          placeholder string.
#
# NOTE(review): consecutive near-duplicate statements below are the
# before/after sides of this diff hunk; they differ only in quote style.
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
# Extension without its leading dot, e.g. "Notes.TXT" -> "txt".
_, ext = os.path.splitext(file_id.lower())
ext = ext.lstrip('.')
# The payload is decoded before the extension is checked, so a payload that
# is not valid UTF-8 yields the first placeholder whatever its extension.
try:
text = file_bytes.decode("utf-8")
text = file_bytes.decode('utf-8')
except UnicodeDecodeError:
# Placeholder text reads: "non-text file or unrecognised encoding".
return "[非文本文件或编码无法识别]"
return '[非文本文件或编码无法识别]'
# Only these extensions are returned as text.
if ext in ["txt", "md", "csv", "log", "py", "html"]:
if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
return text
else:
# Placeholder text reads: "unknown type: .<ext>".
return f"[未知类型: .{ext}]"
return f'[未知类型: .{ext}]'
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
"""
Associates a file with a knowledge base by updating the kb_id in the File table.
"""
# Returns None in every case. A missing knowledge base or file is logged as
# an error and the method returns without changing anything. Any exception
# is rolled back and logged but NOT re-raised, so callers cannot tell
# failure from success.
# NOTE(review): both parameters are annotated `str`, but the HTTP handler
# shown in this commit passes int(...) values - confirm the intended type.
# NOTE(review): the session calls run directly in this coroutine, whereas
# the `_..._sync` helpers elsewhere in this commit go through
# asyncio.to_thread - confirm this is intentional.
# Consecutive near-duplicate statements below are the before/after sides of
# this diff hunk (quote style and line wrapping only).
self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
self.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}')
session = SessionLocal()
try:
# Check that the knowledge base exists
kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
if not kb:
self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.")
self.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
return
# Update the file's kb_id
file_to_update = session.query(File).filter_by(id=file_id).first()
if not file_to_update:
self.logger.error(f"File with ID {file_id} not found.")
self.logger.error(f'File with ID {file_id} not found.')
return
file_to_update.kb_id = kb.id
session.commit()
self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
self.logger.info(
f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
)
except Exception as e:
# Undo the pending change and log; the exception is swallowed here.
session.rollback()
self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True)
self.logger.error(
f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}',
exc_info=True,
)
finally:
# Always release the session, including on the early returns above.
session.close()