From ac03a2dceb1bcb7da5e10571630ee85d47079a58 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Wed, 9 Jul 2025 22:09:46 +0800 Subject: [PATCH] feat: modify the rag.py --- .../http/controller/groups/knowledge_base.py | 72 ++-- pkg/entity/persistence/rag.py | 58 +++ pkg/rag/knowledge/RAG_Manager.py | 354 +++++++++++------- pkg/rag/knowledge/services/database.py | 83 ++-- 4 files changed, 338 insertions(+), 229 deletions(-) create mode 100644 pkg/entity/persistence/rag.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index e9606a3d..ce391042 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,6 +1,6 @@ import quart from .. import group - +import os # 导入 os 用于文件操作 @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): @@ -9,8 +9,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg}) async def initialize(self) -> None: - @self.route('', methods=['POST', 'GET']) - async def _() -> str: + @self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases') + async def handle_knowledge_bases() -> str: if quart.request.method == 'GET': knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ @@ -23,17 +23,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): ] return self.success(code=0, data={'bases': bases_list}, msg='ok') + # POST: create a new knowledge base json_data = await quart.request.json knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - _ = knowledge_base_uuid - return self.success(code=0, data={}, msg='ok') + return self.success(code=0, data={'uuid': knowledge_base_uuid}, msg='ok') - @self.route('/', methods=['GET', 'DELETE']) - async def _(knowledge_base_uuid: str) -> str: + @self.route('/', methods=['GET', 'DELETE'], endpoint='handle_specific_knowledge_base') + async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid)) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') @@ -48,28 +48,42 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok', ) elif quart.request.method == 'DELETE': - await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) return self.success(code=0, msg='ok') - @self.route('//files', methods=['GET']) - async def _(knowledge_base_uuid: str) -> str: - if quart.request.method == 'GET': - files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) - return self.success( - code=0, - data=[ - { - 'id': file.id, - 'file_name': file.file_name, - 'status': file.status, - } - for file in files - ], - msg='ok', - ) - # delete specific file in knowledge base - @self.route('//files/', methods=['DELETE']) - async def _(knowledge_base_uuid: str, file_id: str) -> str: - await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) + @self.route('//files', methods=['GET'], endpoint='get_knowledge_base_files') + async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) + return self.success( + code=0, + data=[ + { + 'id': file.id, + 'file_name': file.file_name, + 'status': file.status, + } + for file in files + ], + msg='ok', + ) + + + @self.route('//files/', methods=['DELETE'], endpoint='delete_specific_file_in_kb') + async def delete_specific_file_in_kb(file_id: str) -> str: + await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) return self.success(code=0, msg='ok') + + @self.route('//files', methods=['POST'], endpoint='relate_file_with_kb') + async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str: + if 'file' not in quart.request.files: + return self.http_status(400, -1, 'No file part in the request') + + json_data = await quart.request.json + file_id = json_data.get('file_id') + if not file_id: + return self.http_status(400, -1, 'File ID is required') + + # 调用服务层方法将文件与知识库关联 + await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) + return self.success(code=0, data={}, msg='ok') \ No newline at end of file diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py new file mode 100644 index 00000000..175720f1 --- /dev/null +++ b/pkg/entity/persistence/rag.py @@ -0,0 +1,58 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker +from datetime import datetime +import os + + +Base = declarative_base() +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db") + + +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False} +) + + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def create_db_and_tables(): + """Creates all database tables defined in the Base.""" + Base.metadata.create_all(bind=engine) + print("Database tables created or already exist.") + +class KnowledgeBase(Base): + __tablename__ = 'kb' + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + description = Column(Text) + created_at = Column(DateTime, default=datetime.utcnow) + embedding_model = Column(String, default='') + top_k = Column(Integer, default=5) + + +class File(Base): + __tablename__ = 'file' + id = Column(Integer, primary_key=True, index=True) + kb_id = Column(Integer, nullable=True) + file_name = Column(String) + path = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + file_type = Column(String) + status = Column(Integer, default=0) + + +class Chunk(Base): + __tablename__ = 'chunks' + id = Column(Integer, primary_key=True, index=True) + file_id = Column(Integer, nullable=True) + + text = Column(Text) + + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, nullable=True) + embedding = Column(LargeBinary) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index 6ded737a..9675371b 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -1,38 +1,42 @@ -# RAG_Manager class (main class, adjust imports as needed) -from __future__ import annotations # For type hinting in Python 3.7+ +# rag_manager.py +from __future__ import annotations import logging import os import asyncio +import json +import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever -from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from pkg.core import app # Adjust the import path as needed - +from pkg.core import app class RAG_Manager: - ap: app.Application - def __init__(self, ap: app.Application,logger: logging.Logger = None): + def __init__(self, ap: app.Application, logger: logging.Logger = None): self.ap = ap self.logger = logger or logging.getLogger(__name__) self.embedding_model_type = None self.embedding_model_name = None self.chroma_manager = None - self.parser = None - self.chunker = None + self.parser = FileParser() + self.chunker = Chunker() self.embedder = None self.retriever = None async def initialize_rag_system(self): + """Initializes the RAG system by creating database tables.""" await asyncio.to_thread(create_db_and_tables) - async def create_specific_model(self, embedding_model_type: str, - embedding_model_name: str): + async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): + """ + Creates and configures the specific embedding model and ChromaDB manager. + This must be called before performing embedding or retrieval operations. + """ self.embedding_model_type = embedding_model_type self.embedding_model_name = embedding_model_name @@ -47,52 +51,38 @@ class RAG_Manager: raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") - - self.parser = FileParser() - self.chunker = Chunker() - # Pass chroma_manager to Embedder and Retriever - self.embedder = Embedder( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager # Inject dependency - ) - self.retriever = Retriever( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager # Inject dependency - ) - + self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) + self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): """ - Creates a new knowledge base with the given name and description. - If a knowledge base with the same name already exists, it returns that one. + Creates a new knowledge base if it doesn't already exist. """ try: - def _get_kb_sync(name): + if not self.embedding_model_type or not kb_name: + raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.") + def _create_kb_sync(): session = SessionLocal() try: - return session.query(KnowledgeBase).filter_by(name=name).first() - finally: - session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_name) - - if not kb: - def _add_kb_sync(): - session = SessionLocal() - try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k) + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() + if not kb: + id = uuid.uuid4().int + new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id) session.add(new_kb) session.commit() session.refresh(new_kb) - return new_kb - finally: - session.close() - kb = await asyncio.to_thread(_add_kb_sync) - except Exception as e: - self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) - raise + self.logger.info(f"Knowledge Base '{kb_name}' created.") + return new_kb.id + else: + self.logger.info(f"Knowledge Base '{kb_name}' already exists.") + except Exception as e: + session.rollback() + self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) + raise + finally: + session.close() + + return await asyncio.to_thread(_create_kb_sync) except Exception as e: self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) raise @@ -108,116 +98,124 @@ class RAG_Manager: return session.query(KnowledgeBase).all() finally: session.close() - - kbs = await asyncio.to_thread(_get_all_kbs_sync) - return kbs + return await asyncio.to_thread(_get_all_kbs_sync) except Exception as e: self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) return [] - + async def get_knowledge_base_by_id(self, kb_id: int): """ - Retrieves a knowledge base by its ID. + Retrieves a specific knowledge base by its ID. """ try: - def _get_kb_sync(kb_id): + def _get_kb_sync(kb_id_param): session = SessionLocal() try: - return session.query(KnowledgeBase).filter_by(id=kb_id).first() + return session.query(KnowledgeBase).filter_by(id=kb_id_param).first() finally: session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_id) - return kb + return await asyncio.to_thread(_get_kb_sync, kb_id) except Exception as e: self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) return None - + async def get_files_by_knowledge_base(self, kb_id: int): + """ + Retrieves files associated with a specific knowledge base by querying the File table directly. + """ try: - def _get_files_sync(kb_id): + def _get_files_sync(kb_id_param): session = SessionLocal() try: - return session.query(File).filter_by(kb_id=kb_id).all() + return session.query(File).filter_by(kb_id=kb_id_param).all() finally: session.close() - - files = await asyncio.to_thread(_get_files_sync, kb_id) - return files + return await asyncio.to_thread(_get_files_sync, kb_id) except Exception as e: self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) return [] + async def get_all_files(self): + """ + Retrieves all files stored in the database, regardless of their association + with any specific knowledge base. + """ + try: + def _get_all_files_sync(): + session = SessionLocal() + try: + return session.query(File).all() + finally: + session.close() + return await asyncio.to_thread(_get_all_files_sync) + except Exception as e: + self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True) + return [] async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + """ + Parses, chunks, embeds, and stores data from a given file into the RAG system. + Associates the file with a knowledge base using kb_id in the File table. + """ self.logger.info(f"Starting data storage process for file: {file_path}") + session = SessionLocal() + file_obj = None + try: - def _get_kb_sync(name): - session = SessionLocal() - try: - return session.query(KnowledgeBase).filter_by(name=name).first() - finally: - session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_name) - + # 1. 确保知识库存在或创建它 + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: - self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.") - def _add_kb_sync(): - session = SessionLocal() - try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description) - session.add(new_kb) - session.commit() - session.refresh(new_kb) - return new_kb - finally: - session.close() - kb = await asyncio.to_thread(_add_kb_sync) - self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})") + kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(kb) + session.commit() + session.refresh(kb) + self.logger.info(f"Knowledge Base '{kb_name}' created during store_data.") + else: + self.logger.info(f"Knowledge Base '{kb_name}' already exists.") - def _add_file_sync(kb_id, file_name, path, file_type): - session = SessionLocal() - try: - file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type) - session.add(file) - session.commit() - session.refresh(file) - return file - finally: - session.close() - - file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type) - self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})") - - text = await self.parser.parse(file_path) - if not text: - self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.") - # You might want to delete the file_obj from the DB here if it's empty. - session = SessionLocal() - try: - session.delete(file_obj) - session.commit() - except Exception as del_e: - self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}") - finally: - session.close() + # 2. 添加文件记录到数据库,并直接关联 kb_id + file_name = os.path.basename(file_path) + existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() + if existing_file: + self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.") return + file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type) + session.add(file_obj) + session.commit() + session.refresh(file_obj) + self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}") + + # 3. 解析文件内容 + text = await self.parser.parse(file_path) + if not text: + self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.") + session.delete(file_obj) + session.commit() # 提交删除操作 + return + + # 4. 分块并嵌入/存储块 chunks_texts = await self.chunker.chunk(text) - self.logger.info(f"Chunked into {len(chunks_texts)} pieces.") - - # embed_and_store now handles both DB chunk saving and Chroma embedding + self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) - self.logger.info(f"Data storage process completed for file: {file_path}") except Exception as e: + session.rollback() self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) - # Consider cleaning up partially stored data if an error occurs. - return + if file_obj and file_obj.id: + try: + await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) + except Exception as chroma_e: + self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}") + raise + finally: + session.close() async def retrieve_data(self, query: str): + """ + Retrieves relevant data chunks based on a given query using the configured retriever. + """ self.logger.info(f"Starting data retrieval process for query: '{query}'") try: retrieved_chunks = await self.retriever.retrieve(query) @@ -229,60 +227,140 @@ class RAG_Manager: async def delete_data_by_file_id(self, file_id: int): """ - Deletes data associated with a specific file_id from both the relational DB and Chroma. + Deletes all data associated with a specific file ID, including its chunks and vectors, + and the file record itself. """ self.logger.info(f"Starting data deletion process for file_id: {file_id}") session = SessionLocal() try: - # 1. Delete from Chroma + # 1. 从 ChromaDB 删除 embeddings await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}") - # 2. Delete chunks from relational DB + # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) - self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.") + self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}") - # 3. Delete file entry from relational DB + # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) - self.logger.info(f"Deleted file entry {file_id} from relational DB.") + self.logger.info(f"Deleted file record for file_id: {file_id}") else: - self.logger.warning(f"File entry {file_id} not found in relational DB.") + self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.") session.commit() - self.logger.info(f"Data deletion completed for file_id: {file_id}.") + self.logger.info(f"Successfully completed data deletion for file_id: {file_id}") except Exception as e: session.rollback() self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + raise finally: session.close() async def delete_kb_by_id(self, kb_id: int): """ - Deletes a knowledge base and all associated files and chunks. + Deletes a knowledge base and all associated files, chunks, and vectors. + This involves querying for associated files and then deleting them. """ self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") - session = SessionLocal() + session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + try: - # 1. Get the knowledge base - kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb: + kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if not kb_to_delete: self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") return - # 2. Delete all files associated with this knowledge base - files_to_delete = session.query(File).filter_by(kb_id=kb.id).all() - for file in files_to_delete: - await self.delete_data_by_file_id(file.id) + # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 + files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() + + # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 + session.close() - # 3. Delete the knowledge base itself - session.delete(kb) + # 遍历删除每个关联文件及其数据 + for file_obj in files_to_delete: + try: + await self.delete_data_by_file_id(file_obj.id) + except Exception as file_del_e: + self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}") + # 记录错误但继续,尝试删除其他文件 + + # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 + session = SessionLocal() + try: + # 重新查询,确保对象是当前会话的一部分 + kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if kb_final_delete: + session.delete(kb_final_delete) + session.commit() + self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + else: + self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.") + except Exception as kb_del_e: + session.rollback() + self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True) + raise + finally: + session.close() + + except Exception as e: + # 如果在最初获取 KB 或文件列表时出错 + if session.is_active: + session.rollback() + self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True) + raise + finally: + if session.is_active: + session.close() + + + + async def get_file_content_by_file_id(self, file_id: str) -> str: + + file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) + + _, ext = os.path.splitext(file_id.lower()) + ext = ext.lstrip('.') + + try: + text = file_bytes.decode("utf-8") + except UnicodeDecodeError: + return "[非文本文件或编码无法识别]" + + if ext in ["txt", "md", "csv", "log", "py", "html"]: + return text + else: + return f"[未知类型: .{ext}]" + + async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: + """ + Associates a file with a knowledge base by updating the kb_id in the File table. + """ + self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + session = SessionLocal() + try: + # 查询知识库是否存在 + kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() + if not kb: + self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.") + return + + # 更新文件的 kb_id + file_to_update = session.query(File).filter_by(id=file_id).first() + if not file_to_update: + self.logger.error(f"File with ID {file_id} not found.") + return + + file_to_update.kb_id = kb.id session.commit() - self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") except Exception as e: session.rollback() - self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True) finally: session.close() + + diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index 35a52453..bc5caa10 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -1,64 +1,23 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary -from sqlalchemy.orm import declarative_base, sessionmaker, relationship -from datetime import datetime +# 全部迁移过去 -Base = declarative_base() +from pkg.entity.persistence.rag import ( + create_db_and_tables, + SessionLocal, + Base, + engine, + KnowledgeBase, + File, + Chunk, + Vector, +) - -class KnowledgeBase(Base): - __tablename__ = 'kb' - id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True) - description = Column(Text) - created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default='') # 默认嵌入模型 - top_k = Column(Integer, default=5) # 默认返回的top_k数量 - files = relationship('File', back_populates='knowledge_base') - - -class File(Base): - __tablename__ = 'file' - id = Column(Integer, primary_key=True, index=True) - kb_id = Column(Integer, ForeignKey('kb.id')) - file_name = Column(String) - path = Column(String) - created_at = Column(DateTime, default=datetime.utcnow) - file_type = Column(String) - status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 - knowledge_base = relationship('KnowledgeBase', back_populates='files') - chunks = relationship('Chunk', back_populates='file') - - -class Chunk(Base): - __tablename__ = 'chunks' - id = Column(Integer, primary_key=True, index=True) - file_id = Column(Integer, ForeignKey('file.id')) - text = Column(Text) - - file = relationship('File', back_populates='chunks') - vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one - - -class Vector(Base): - __tablename__ = 'vectors' - id = Column(Integer, primary_key=True, index=True) - chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary - - chunk = relationship('Chunk', back_populates='vector') - - -# 数据库连接 -DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL -engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {}) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -# 创建所有表 (可以在应用启动时执行一次) -def create_db_and_tables(): - Base.metadata.create_all(bind=engine) - print('Database tables created/checked.') - - -# 定义嵌入维度(请根据你实际使用的模型调整) -EMBEDDING_DIM = 1024 +__all__ = [ + "create_db_and_tables", + "SessionLocal", + "Base", + "engine", + "KnowledgeBase", + "File", + "Chunk", + "Vector", +]