mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: modify the rag.py
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import quart
|
||||
from .. import group
|
||||
|
||||
import os # 导入 os 用于文件操作
|
||||
|
||||
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
|
||||
class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
@@ -9,8 +9,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg})
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['POST', 'GET'])
|
||||
async def _() -> str:
|
||||
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
|
||||
async def handle_knowledge_bases() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
|
||||
bases_list = [
|
||||
@@ -23,17 +23,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
]
|
||||
return self.success(code=0, data={'bases': bases_list}, msg='ok')
|
||||
|
||||
# POST: create a new knowledge base
|
||||
json_data = await quart.request.json
|
||||
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
|
||||
json_data.get('name'), json_data.get('description')
|
||||
)
|
||||
_ = knowledge_base_uuid
|
||||
return self.success(code=0, data={}, msg='ok')
|
||||
return self.success(code=0, data={'uuid': knowledge_base_uuid}, msg='ok')
|
||||
|
||||
@self.route('/<knowledge_base_uuid>', methods=['GET', 'DELETE'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
@self.route('/<knowledge_base_uuid>', methods=['GET', 'DELETE'], endpoint='handle_specific_knowledge_base')
|
||||
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
|
||||
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid))
|
||||
|
||||
if knowledge_base is None:
|
||||
return self.http_status(404, -1, 'knowledge base not found')
|
||||
@@ -48,28 +48,42 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
msg='ok',
|
||||
)
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
|
||||
await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid))
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
|
||||
return self.success(
|
||||
code=0,
|
||||
data=[
|
||||
{
|
||||
'id': file.id,
|
||||
'file_name': file.file_name,
|
||||
'status': file.status,
|
||||
}
|
||||
for file in files
|
||||
],
|
||||
msg='ok',
|
||||
)
|
||||
|
||||
# delete specific file in knowledge base
|
||||
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
|
||||
async def _(knowledge_base_uuid: str, file_id: str) -> str:
|
||||
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
|
||||
@self.route('/<knowledge_base_uuid>/files', methods=['GET'], endpoint='get_knowledge_base_files')
|
||||
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
|
||||
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid))
|
||||
return self.success(
|
||||
code=0,
|
||||
data=[
|
||||
{
|
||||
'id': file.id,
|
||||
'file_name': file.file_name,
|
||||
'status': file.status,
|
||||
}
|
||||
for file in files
|
||||
],
|
||||
msg='ok',
|
||||
)
|
||||
|
||||
|
||||
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'], endpoint='delete_specific_file_in_kb')
|
||||
async def delete_specific_file_in_kb(file_id: str) -> str:
|
||||
await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id))
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@self.route('/<knowledge_base_uuid>/files', methods=['POST'], endpoint='relate_file_with_kb')
|
||||
async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str:
|
||||
if 'file' not in quart.request.files:
|
||||
return self.http_status(400, -1, 'No file part in the request')
|
||||
|
||||
json_data = await quart.request.json
|
||||
file_id = json_data.get('file_id')
|
||||
if not file_id:
|
||||
return self.http_status(400, -1, 'File ID is required')
|
||||
|
||||
# 调用服务层方法将文件与知识库关联
|
||||
await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id))
|
||||
return self.success(code=0, data={}, msg='ok')
|
||||
58
pkg/entity/persistence/rag.py
Normal file
58
pkg/entity/persistence/rag.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db")
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
"""Creates all database tables defined in the Base."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print("Database tables created or already exist.")
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default='')
|
||||
top_k = Column(Integer, default=5)
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
kb_id = Column(Integer, nullable=True)
|
||||
file_name = Column(String)
|
||||
path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0)
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
file_id = Column(Integer, nullable=True)
|
||||
|
||||
text = Column(Text)
|
||||
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, nullable=True)
|
||||
embedding = Column(LargeBinary)
|
||||
@@ -1,38 +1,42 @@
|
||||
# RAG_Manager class (main class, adjust imports as needed)
|
||||
from __future__ import annotations # For type hinting in Python 3.7+
|
||||
# rag_manager.py
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
|
||||
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from pkg.core import app # Adjust the import path as needed
|
||||
|
||||
from pkg.core import app
|
||||
|
||||
class RAG_Manager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application,logger: logging.Logger = None):
|
||||
def __init__(self, ap: app.Application, logger: logging.Logger = None):
|
||||
self.ap = ap
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.embedding_model_type = None
|
||||
self.embedding_model_name = None
|
||||
self.chroma_manager = None
|
||||
self.parser = None
|
||||
self.chunker = None
|
||||
self.parser = FileParser()
|
||||
self.chunker = Chunker()
|
||||
self.embedder = None
|
||||
self.retriever = None
|
||||
|
||||
async def initialize_rag_system(self):
|
||||
"""Initializes the RAG system by creating database tables."""
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
|
||||
async def create_specific_model(self, embedding_model_type: str,
|
||||
embedding_model_name: str):
|
||||
async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str):
|
||||
"""
|
||||
Creates and configures the specific embedding model and ChromaDB manager.
|
||||
This must be called before performing embedding or retrieval operations.
|
||||
"""
|
||||
self.embedding_model_type = embedding_model_type
|
||||
self.embedding_model_name = embedding_model_name
|
||||
|
||||
@@ -47,52 +51,38 @@ class RAG_Manager:
|
||||
raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
|
||||
|
||||
self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
|
||||
|
||||
self.parser = FileParser()
|
||||
self.chunker = Chunker()
|
||||
# Pass chroma_manager to Embedder and Retriever
|
||||
self.embedder = Embedder(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager # Inject dependency
|
||||
)
|
||||
self.retriever = Retriever(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager # Inject dependency
|
||||
)
|
||||
|
||||
self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
|
||||
self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
|
||||
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5):
|
||||
"""
|
||||
Creates a new knowledge base with the given name and description.
|
||||
If a knowledge base with the same name already exists, it returns that one.
|
||||
Creates a new knowledge base if it doesn't already exist.
|
||||
"""
|
||||
try:
|
||||
def _get_kb_sync(name):
|
||||
if not self.embedding_model_type or not kb_name:
|
||||
raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.")
|
||||
def _create_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(name=name).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
|
||||
|
||||
if not kb:
|
||||
def _add_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k)
|
||||
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
|
||||
if not kb:
|
||||
id = uuid.uuid4().int
|
||||
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id)
|
||||
session.add(new_kb)
|
||||
session.commit()
|
||||
session.refresh(new_kb)
|
||||
return new_kb
|
||||
finally:
|
||||
session.close()
|
||||
kb = await asyncio.to_thread(_add_kb_sync)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' created.")
|
||||
return new_kb.id
|
||||
else:
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' already exists.")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return await asyncio.to_thread(_create_kb_sync)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -108,116 +98,124 @@ class RAG_Manager:
|
||||
return session.query(KnowledgeBase).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kbs = await asyncio.to_thread(_get_all_kbs_sync)
|
||||
return kbs
|
||||
return await asyncio.to_thread(_get_all_kbs_sync)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_knowledge_base_by_id(self, kb_id: int):
|
||||
"""
|
||||
Retrieves a knowledge base by its ID.
|
||||
Retrieves a specific knowledge base by its ID.
|
||||
"""
|
||||
try:
|
||||
def _get_kb_sync(kb_id):
|
||||
def _get_kb_sync(kb_id_param):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_id)
|
||||
return kb
|
||||
return await asyncio.to_thread(_get_kb_sync, kb_id)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_id: int):
|
||||
"""
|
||||
Retrieves files associated with a specific knowledge base by querying the File table directly.
|
||||
"""
|
||||
try:
|
||||
def _get_files_sync(kb_id):
|
||||
def _get_files_sync(kb_id_param):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(File).filter_by(kb_id=kb_id).all()
|
||||
return session.query(File).filter_by(kb_id=kb_id_param).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
files = await asyncio.to_thread(_get_files_sync, kb_id)
|
||||
return files
|
||||
return await asyncio.to_thread(_get_files_sync, kb_id)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_all_files(self):
|
||||
"""
|
||||
Retrieves all files stored in the database, regardless of their association
|
||||
with any specific knowledge base.
|
||||
"""
|
||||
try:
|
||||
def _get_all_files_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(File).all()
|
||||
finally:
|
||||
session.close()
|
||||
return await asyncio.to_thread(_get_all_files_sync)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
|
||||
"""
|
||||
Parses, chunks, embeds, and stores data from a given file into the RAG system.
|
||||
Associates the file with a knowledge base using kb_id in the File table.
|
||||
"""
|
||||
self.logger.info(f"Starting data storage process for file: {file_path}")
|
||||
session = SessionLocal()
|
||||
file_obj = None
|
||||
|
||||
try:
|
||||
def _get_kb_sync(name):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(name=name).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
|
||||
|
||||
# 1. 确保知识库存在或创建它
|
||||
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
|
||||
if not kb:
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.")
|
||||
def _add_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
new_kb = KnowledgeBase(name=kb_name, description=kb_description)
|
||||
session.add(new_kb)
|
||||
session.commit()
|
||||
session.refresh(new_kb)
|
||||
return new_kb
|
||||
finally:
|
||||
session.close()
|
||||
kb = await asyncio.to_thread(_add_kb_sync)
|
||||
self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})")
|
||||
kb = KnowledgeBase(name=kb_name, description=kb_description)
|
||||
session.add(kb)
|
||||
session.commit()
|
||||
session.refresh(kb)
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' created during store_data.")
|
||||
else:
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' already exists.")
|
||||
|
||||
def _add_file_sync(kb_id, file_name, path, file_type):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type)
|
||||
session.add(file)
|
||||
session.commit()
|
||||
session.refresh(file)
|
||||
return file
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type)
|
||||
self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})")
|
||||
|
||||
text = await self.parser.parse(file_path)
|
||||
if not text:
|
||||
self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.")
|
||||
# You might want to delete the file_obj from the DB here if it's empty.
|
||||
session = SessionLocal()
|
||||
try:
|
||||
session.delete(file_obj)
|
||||
session.commit()
|
||||
except Exception as del_e:
|
||||
self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}")
|
||||
finally:
|
||||
session.close()
|
||||
# 2. 添加文件记录到数据库,并直接关联 kb_id
|
||||
file_name = os.path.basename(file_path)
|
||||
existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
|
||||
if existing_file:
|
||||
self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.")
|
||||
return
|
||||
|
||||
file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
|
||||
session.add(file_obj)
|
||||
session.commit()
|
||||
session.refresh(file_obj)
|
||||
self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}")
|
||||
|
||||
# 3. 解析文件内容
|
||||
text = await self.parser.parse(file_path)
|
||||
if not text:
|
||||
self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.")
|
||||
session.delete(file_obj)
|
||||
session.commit() # 提交删除操作
|
||||
return
|
||||
|
||||
# 4. 分块并嵌入/存储块
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
self.logger.info(f"Chunked into {len(chunks_texts)} pieces.")
|
||||
|
||||
# embed_and_store now handles both DB chunk saving and Chroma embedding
|
||||
self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
|
||||
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
|
||||
|
||||
self.logger.info(f"Data storage process completed for file: {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True)
|
||||
# Consider cleaning up partially stored data if an error occurs.
|
||||
return
|
||||
if file_obj and file_obj.id:
|
||||
try:
|
||||
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
|
||||
except Exception as chroma_e:
|
||||
self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def retrieve_data(self, query: str):
|
||||
"""
|
||||
Retrieves relevant data chunks based on a given query using the configured retriever.
|
||||
"""
|
||||
self.logger.info(f"Starting data retrieval process for query: '{query}'")
|
||||
try:
|
||||
retrieved_chunks = await self.retriever.retrieve(query)
|
||||
@@ -229,60 +227,140 @@ class RAG_Manager:
|
||||
|
||||
async def delete_data_by_file_id(self, file_id: int):
|
||||
"""
|
||||
Deletes data associated with a specific file_id from both the relational DB and Chroma.
|
||||
Deletes all data associated with a specific file ID, including its chunks and vectors,
|
||||
and the file record itself.
|
||||
"""
|
||||
self.logger.info(f"Starting data deletion process for file_id: {file_id}")
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 1. Delete from Chroma
|
||||
# 1. 从 ChromaDB 删除 embeddings
|
||||
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
|
||||
self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}")
|
||||
|
||||
# 2. Delete chunks from relational DB
|
||||
# 2. 删除与文件关联的 chunks 记录
|
||||
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
|
||||
for chunk in chunks_to_delete:
|
||||
session.delete(chunk)
|
||||
self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.")
|
||||
self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}")
|
||||
|
||||
# 3. Delete file entry from relational DB
|
||||
# 3. 删除文件记录本身
|
||||
file_to_delete = session.query(File).filter_by(id=file_id).first()
|
||||
if file_to_delete:
|
||||
session.delete(file_to_delete)
|
||||
self.logger.info(f"Deleted file entry {file_id} from relational DB.")
|
||||
self.logger.info(f"Deleted file record for file_id: {file_id}")
|
||||
else:
|
||||
self.logger.warning(f"File entry {file_id} not found in relational DB.")
|
||||
self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.")
|
||||
|
||||
session.commit()
|
||||
self.logger.info(f"Data deletion completed for file_id: {file_id}.")
|
||||
self.logger.info(f"Successfully completed data deletion for file_id: {file_id}")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def delete_kb_by_id(self, kb_id: int):
|
||||
"""
|
||||
Deletes a knowledge base and all associated files and chunks.
|
||||
Deletes a knowledge base and all associated files, chunks, and vectors.
|
||||
This involves querying for associated files and then deleting them.
|
||||
"""
|
||||
self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
|
||||
session = SessionLocal()
|
||||
session = SessionLocal() # 使用新的会话来获取 KB 和关联文件
|
||||
|
||||
try:
|
||||
# 1. Get the knowledge base
|
||||
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if not kb:
|
||||
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if not kb_to_delete:
|
||||
self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
|
||||
return
|
||||
|
||||
# 2. Delete all files associated with this knowledge base
|
||||
files_to_delete = session.query(File).filter_by(kb_id=kb.id).all()
|
||||
for file in files_to_delete:
|
||||
await self.delete_data_by_file_id(file.id)
|
||||
# 获取所有关联的文件,通过 File 表的 kb_id 字段查询
|
||||
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
|
||||
|
||||
# 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话
|
||||
session.close()
|
||||
|
||||
# 3. Delete the knowledge base itself
|
||||
session.delete(kb)
|
||||
# 遍历删除每个关联文件及其数据
|
||||
for file_obj in files_to_delete:
|
||||
try:
|
||||
await self.delete_data_by_file_id(file_obj.id)
|
||||
except Exception as file_del_e:
|
||||
self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}")
|
||||
# 记录错误但继续,尝试删除其他文件
|
||||
|
||||
# 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 重新查询,确保对象是当前会话的一部分
|
||||
kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if kb_final_delete:
|
||||
session.delete(kb_final_delete)
|
||||
session.commit()
|
||||
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
|
||||
else:
|
||||
self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.")
|
||||
except Exception as kb_del_e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
# 如果在最初获取 KB 或文件列表时出错
|
||||
if session.is_active:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if session.is_active:
|
||||
session.close()
|
||||
|
||||
|
||||
|
||||
async def get_file_content_by_file_id(self, file_id: str) -> str:
|
||||
|
||||
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
|
||||
|
||||
_, ext = os.path.splitext(file_id.lower())
|
||||
ext = ext.lstrip('.')
|
||||
|
||||
try:
|
||||
text = file_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return "[非文本文件或编码无法识别]"
|
||||
|
||||
if ext in ["txt", "md", "csv", "log", "py", "html"]:
|
||||
return text
|
||||
else:
|
||||
return f"[未知类型: .{ext}]"
|
||||
|
||||
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
|
||||
"""
|
||||
Associates a file with a knowledge base by updating the kb_id in the File table.
|
||||
"""
|
||||
self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 查询知识库是否存在
|
||||
kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
|
||||
if not kb:
|
||||
self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.")
|
||||
return
|
||||
|
||||
# 更新文件的 kb_id
|
||||
file_to_update = session.query(File).filter_by(id=file_id).first()
|
||||
if not file_to_update:
|
||||
self.logger.error(f"File with ID {file_id} not found.")
|
||||
return
|
||||
|
||||
file_to_update.kb_id = kb.id
|
||||
session.commit()
|
||||
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
|
||||
self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
|
||||
self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
|
||||
@@ -1,64 +1,23 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
# 全部迁移过去
|
||||
|
||||
Base = declarative_base()
|
||||
from pkg.entity.persistence.rag import (
|
||||
create_db_and_tables,
|
||||
SessionLocal,
|
||||
Base,
|
||||
engine,
|
||||
KnowledgeBase,
|
||||
File,
|
||||
Chunk,
|
||||
Vector,
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default='') # 默认嵌入模型
|
||||
top_k = Column(Integer, default=5) # 默认返回的top_k数量
|
||||
files = relationship('File', back_populates='knowledge_base')
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
kb_id = Column(Integer, ForeignKey('kb.id'))
|
||||
file_name = Column(String)
|
||||
path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
|
||||
knowledge_base = relationship('KnowledgeBase', back_populates='files')
|
||||
chunks = relationship('Chunk', back_populates='file')
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
file_id = Column(Integer, ForeignKey('file.id'))
|
||||
text = Column(Text)
|
||||
|
||||
file = relationship('File', back_populates='chunks')
|
||||
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
|
||||
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship('Chunk', back_populates='vector')
|
||||
|
||||
|
||||
# 数据库连接
|
||||
DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL
|
||||
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
# 创建所有表 (可以在应用启动时执行一次)
|
||||
def create_db_and_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print('Database tables created/checked.')
|
||||
|
||||
|
||||
# 定义嵌入维度(请根据你实际使用的模型调整)
|
||||
EMBEDDING_DIM = 1024
|
||||
__all__ = [
|
||||
"create_db_and_tables",
|
||||
"SessionLocal",
|
||||
"Base",
|
||||
"engine",
|
||||
"KnowledgeBase",
|
||||
"File",
|
||||
"Chunk",
|
||||
"Vector",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user