From 14c161b73316e268e91b30f6705af9bec0652e6a Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Fri, 11 Jul 2025 18:14:03 +0800 Subject: [PATCH] fix: create knwoledge base issue --- pkg/entity/persistence/rag.py | 26 ++++++++++---------------- pkg/rag/knowledge/mgr.py | 34 +++++++++++++++++----------------- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 1657196a..95a78712 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -1,19 +1,17 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary +from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer from sqlalchemy.orm import declarative_base, sessionmaker from datetime import datetime import os - Base = declarative_base() DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') +print("Using database URL:", DATABASE_URL) engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - def create_db_and_tables(): """Creates all database tables defined in the Base.""" Base.metadata.create_all(bind=engine) @@ -22,35 +20,31 @@ def create_db_and_tables(): class KnowledgeBase(Base): __tablename__ = 'kb' - id = Column(Integer, primary_key=True, index=True) + id = Column(String, primary_key=True, index=True) name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) embedding_model_uuid = Column(String, default='') top_k = Column(Integer, default=5) - class File(Base): __tablename__ = 'file' - id = Column(Integer, primary_key=True, index=True) - kb_id = Column(Integer, nullable=True) + id = Column(String, primary_key=True, index=True) + kb_id = Column(String, nullable=True) file_name = Column(String) path = Column(String) created_at = Column(DateTime, default=datetime.utcnow) file_type = Column(String) - status = Column(Integer, default=0) - + status = Column(String, default='0') class Chunk(Base): __tablename__ = 'chunks' - id = Column(Integer, primary_key=True, index=True) - file_id = Column(Integer, nullable=True) - + id = Column(String, primary_key=True, index=True) + file_id = Column(String, nullable=True) text = Column(Text) - class Vector(Base): __tablename__ = 'vectors' - id = Column(Integer, primary_key=True, index=True) - chunk_id = Column(Integer, nullable=True) + id = Column(String, primary_key=True, index=True) + chunk_id = Column(String, nullable=True) embedding = Column(LargeBinary) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 4da10a09..585a5075 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -41,7 +41,7 @@ class RAGManager: try: kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: - id = uuid.uuid4().int + id = str(uuid.uuid4()) new_kb = KnowledgeBase( name=kb_name, description=kb_description, @@ -86,7 +86,7 @@ class RAGManager: self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) return [] - async def get_knowledge_base_by_id(self, kb_id: int): + async def get_knowledge_base_by_id(self, kb_id: str): """ Retrieves a specific knowledge base by its ID. """ @@ -104,7 +104,7 @@ class RAGManager: self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) return None - async def get_files_by_knowledge_base(self, kb_id: int): + async def get_files_by_knowledge_base(self, kb_id: str): """ Retrieves files associated with a specific knowledge base by querying the File table directly. """ @@ -153,7 +153,7 @@ class RAGManager: file_obj = None try: - # 1. 确保知识库存在或创建它 + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: kb = KnowledgeBase(name=kb_name, description=kb_description) @@ -164,7 +164,7 @@ class RAGManager: else: self.logger.info(f"Knowledge Base '{kb_name}' already exists.") - # 2. 添加文件记录到数据库,并直接关联 kb_id + file_name = os.path.basename(file_path) existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() if existing_file: @@ -181,15 +181,15 @@ class RAGManager: f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" ) - # 3. 解析文件内容 + text = await self.parser.parse(file_path) if not text: self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.') session.delete(file_obj) - session.commit() # 提交删除操作 + session.commit() return - # 4. 分块并嵌入/存储块 + chunks_texts = await self.chunker.chunk(text) self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) @@ -222,7 +222,7 @@ class RAGManager: self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) return [] - async def delete_data_by_file_id(self, file_id: int): + async def delete_data_by_file_id(self, file_id: str): """ Deletes all data associated with a specific file ID, including its chunks and vectors, and the file record itself. @@ -257,13 +257,13 @@ class RAGManager: finally: session.close() - async def delete_kb_by_id(self, kb_id: int): + async def delete_kb_by_id(self, kb_id: str): """ Deletes a knowledge base and all associated files, chunks, and vectors. This involves querying for associated files and then deleting them. """ self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') - session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + session = SessionLocal() try: kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() @@ -271,24 +271,24 @@ class RAGManager: self.logger.warning(f'Knowledge Base with ID {kb_id} not found.') return - # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 + files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 + session.close() - # 遍历删除每个关联文件及其数据 + for file_obj in files_to_delete: try: await self.delete_data_by_file_id(file_obj.id) except Exception as file_del_e: self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - # 记录错误但继续,尝试删除其他文件 + - # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 + session = SessionLocal() try: - # 重新查询,确保对象是当前会话的一部分 + kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if kb_final_delete: session.delete(kb_final_delete)