feat: modify the rag.py

This commit is contained in:
WangCham
2025-07-09 22:09:46 +08:00
parent cd25340826
commit ac03a2dceb
4 changed files with 338 additions and 229 deletions

View File

@@ -1,6 +1,6 @@
import quart
from .. import group
import os # 导入 os 用于文件操作
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
@@ -9,8 +9,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg})
async def initialize(self) -> None:
@self.route('', methods=['POST', 'GET'])
async def _() -> str:
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
async def handle_knowledge_bases() -> str:
if quart.request.method == 'GET':
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
bases_list = [
@@ -23,17 +23,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
]
return self.success(code=0, data={'bases': bases_list}, msg='ok')
# POST: create a new knowledge base
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'), json_data.get('description')
)
_ = knowledge_base_uuid
return self.success(code=0, data={}, msg='ok')
return self.success(code=0, data={'uuid': knowledge_base_uuid}, msg='ok')
@self.route('/<knowledge_base_uuid>', methods=['GET', 'DELETE'])
async def _(knowledge_base_uuid: str) -> str:
@self.route('/<knowledge_base_uuid>', methods=['GET', 'DELETE'], endpoint='handle_specific_knowledge_base')
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid))
if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')
@@ -48,28 +48,42 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
msg='ok',
)
elif quart.request.method == 'DELETE':
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid))
return self.success(code=0, msg='ok')
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
async def _(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(
code=0,
data=[
{
'id': file.id,
'file_name': file.file_name,
'status': file.status,
}
for file in files
],
msg='ok',
)
# delete specific file in knowledge base
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
async def _(knowledge_base_uuid: str, file_id: str) -> str:
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
@self.route('/<knowledge_base_uuid>/files', methods=['GET'], endpoint='get_knowledge_base_files')
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid))
return self.success(
code=0,
data=[
{
'id': file.id,
'file_name': file.file_name,
'status': file.status,
}
for file in files
],
msg='ok',
)
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'], endpoint='delete_specific_file_in_kb')
async def delete_specific_file_in_kb(file_id: str) -> str:
await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id))
return self.success(code=0, msg='ok')
@self.route('/<knowledge_base_uuid>/files', methods=['POST'], endpoint='relate_file_with_kb')
async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str:
if 'file' not in quart.request.files:
return self.http_status(400, -1, 'No file part in the request')
json_data = await quart.request.json
file_id = json_data.get('file_id')
if not file_id:
return self.http_status(400, -1, 'File ID is required')
# 调用服务层方法将文件与知识库关联
await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id))
return self.success(code=0, data={}, msg='ok')

View File

@@ -0,0 +1,58 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime
import os
Base = declarative_base()
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db")
engine = create_engine(
DATABASE_URL,
connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def create_db_and_tables():
"""Creates all database tables defined in the Base."""
Base.metadata.create_all(bind=engine)
print("Database tables created or already exist.")
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model = Column(String, default='')
top_k = Column(Integer, default=5)
class File(Base):
__tablename__ = 'file'
id = Column(Integer, primary_key=True, index=True)
kb_id = Column(Integer, nullable=True)
file_name = Column(String)
path = Column(String)
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0)
class Chunk(Base):
__tablename__ = 'chunks'
id = Column(Integer, primary_key=True, index=True)
file_id = Column(Integer, nullable=True)
text = Column(Text)
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, nullable=True)
embedding = Column(LargeBinary)

View File

@@ -1,38 +1,42 @@
# RAG_Manager class (main class, adjust imports as needed)
from __future__ import annotations # For type hinting in Python 3.7+
# rag_manager.py
from __future__ import annotations
import logging
import os
import asyncio
import json
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.core import app # Adjust the import path as needed
from pkg.core import app
class RAG_Manager:
ap: app.Application
def __init__(self, ap: app.Application,logger: logging.Logger = None):
def __init__(self, ap: app.Application, logger: logging.Logger = None):
self.ap = ap
self.logger = logger or logging.getLogger(__name__)
self.embedding_model_type = None
self.embedding_model_name = None
self.chroma_manager = None
self.parser = None
self.chunker = None
self.parser = FileParser()
self.chunker = Chunker()
self.embedder = None
self.retriever = None
async def initialize_rag_system(self):
"""Initializes the RAG system by creating database tables."""
await asyncio.to_thread(create_db_and_tables)
async def create_specific_model(self, embedding_model_type: str,
embedding_model_name: str):
async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str):
"""
Creates and configures the specific embedding model and ChromaDB manager.
This must be called before performing embedding or retrieval operations.
"""
self.embedding_model_type = embedding_model_type
self.embedding_model_name = embedding_model_name
@@ -47,52 +51,38 @@ class RAG_Manager:
raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
self.parser = FileParser()
self.chunker = Chunker()
# Pass chroma_manager to Embedder and Retriever
self.embedder = Embedder(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager # Inject dependency
)
self.retriever = Retriever(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager # Inject dependency
)
self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager)
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5):
"""
Creates a new knowledge base with the given name and description.
If a knowledge base with the same name already exists, it returns that one.
Creates a new knowledge base if it doesn't already exist.
"""
try:
def _get_kb_sync(name):
if not self.embedding_model_type or not kb_name:
raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.")
def _create_kb_sync():
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(name=name).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
if not kb:
def _add_kb_sync():
session = SessionLocal()
try:
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k)
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
if not kb:
id = uuid.uuid4().int
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
return new_kb
finally:
session.close()
kb = await asyncio.to_thread(_add_kb_sync)
except Exception as e:
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
raise
self.logger.info(f"Knowledge Base '{kb_name}' created.")
return new_kb.id
else:
self.logger.info(f"Knowledge Base '{kb_name}' already exists.")
except Exception as e:
session.rollback()
self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
raise
finally:
session.close()
return await asyncio.to_thread(_create_kb_sync)
except Exception as e:
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
raise
@@ -108,116 +98,124 @@ class RAG_Manager:
return session.query(KnowledgeBase).all()
finally:
session.close()
kbs = await asyncio.to_thread(_get_all_kbs_sync)
return kbs
return await asyncio.to_thread(_get_all_kbs_sync)
except Exception as e:
self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
return []
async def get_knowledge_base_by_id(self, kb_id: int):
"""
Retrieves a knowledge base by its ID.
Retrieves a specific knowledge base by its ID.
"""
try:
def _get_kb_sync(kb_id):
def _get_kb_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(id=kb_id).first()
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_id)
return kb
return await asyncio.to_thread(_get_kb_sync, kb_id)
except Exception as e:
self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
return None
async def get_files_by_knowledge_base(self, kb_id: int):
"""
Retrieves files associated with a specific knowledge base by querying the File table directly.
"""
try:
def _get_files_sync(kb_id):
def _get_files_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(File).filter_by(kb_id=kb_id).all()
return session.query(File).filter_by(kb_id=kb_id_param).all()
finally:
session.close()
files = await asyncio.to_thread(_get_files_sync, kb_id)
return files
return await asyncio.to_thread(_get_files_sync, kb_id)
except Exception as e:
self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
return []
async def get_all_files(self):
"""
Retrieves all files stored in the database, regardless of their association
with any specific knowledge base.
"""
try:
def _get_all_files_sync():
session = SessionLocal()
try:
return session.query(File).all()
finally:
session.close()
return await asyncio.to_thread(_get_all_files_sync)
except Exception as e:
self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True)
return []
async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
"""
Parses, chunks, embeds, and stores data from a given file into the RAG system.
Associates the file with a knowledge base using kb_id in the File table.
"""
self.logger.info(f"Starting data storage process for file: {file_path}")
session = SessionLocal()
file_obj = None
try:
def _get_kb_sync(name):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(name=name).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
# 1. 确保知识库存在或创建它
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
if not kb:
self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.")
def _add_kb_sync():
session = SessionLocal()
try:
new_kb = KnowledgeBase(name=kb_name, description=kb_description)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
return new_kb
finally:
session.close()
kb = await asyncio.to_thread(_add_kb_sync)
self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})")
kb = KnowledgeBase(name=kb_name, description=kb_description)
session.add(kb)
session.commit()
session.refresh(kb)
self.logger.info(f"Knowledge Base '{kb_name}' created during store_data.")
else:
self.logger.info(f"Knowledge Base '{kb_name}' already exists.")
def _add_file_sync(kb_id, file_name, path, file_type):
session = SessionLocal()
try:
file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type)
session.add(file)
session.commit()
session.refresh(file)
return file
finally:
session.close()
file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type)
self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})")
text = await self.parser.parse(file_path)
if not text:
self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.")
# You might want to delete the file_obj from the DB here if it's empty.
session = SessionLocal()
try:
session.delete(file_obj)
session.commit()
except Exception as del_e:
self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}")
finally:
session.close()
# 2. 添加文件记录到数据库,并直接关联 kb_id
file_name = os.path.basename(file_path)
existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
if existing_file:
self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.")
return
file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
session.add(file_obj)
session.commit()
session.refresh(file_obj)
self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}")
# 3. 解析文件内容
text = await self.parser.parse(file_path)
if not text:
self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.")
session.delete(file_obj)
session.commit() # 提交删除操作
return
# 4. 分块并嵌入/存储块
chunks_texts = await self.chunker.chunk(text)
self.logger.info(f"Chunked into {len(chunks_texts)} pieces.")
# embed_and_store now handles both DB chunk saving and Chroma embedding
self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
self.logger.info(f"Data storage process completed for file: {file_path}")
except Exception as e:
session.rollback()
self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True)
# Consider cleaning up partially stored data if an error occurs.
return
if file_obj and file_obj.id:
try:
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
except Exception as chroma_e:
self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}")
raise
finally:
session.close()
async def retrieve_data(self, query: str):
"""
Retrieves relevant data chunks based on a given query using the configured retriever.
"""
self.logger.info(f"Starting data retrieval process for query: '{query}'")
try:
retrieved_chunks = await self.retriever.retrieve(query)
@@ -229,60 +227,140 @@ class RAG_Manager:
async def delete_data_by_file_id(self, file_id: int):
"""
Deletes data associated with a specific file_id from both the relational DB and Chroma.
Deletes all data associated with a specific file ID, including its chunks and vectors,
and the file record itself.
"""
self.logger.info(f"Starting data deletion process for file_id: {file_id}")
session = SessionLocal()
try:
# 1. Delete from Chroma
# 1. 从 ChromaDB 删除 embeddings
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}")
# 2. Delete chunks from relational DB
# 2. 删除与文件关联的 chunks 记录
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
for chunk in chunks_to_delete:
session.delete(chunk)
self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.")
self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}")
# 3. Delete file entry from relational DB
# 3. 删除文件记录本身
file_to_delete = session.query(File).filter_by(id=file_id).first()
if file_to_delete:
session.delete(file_to_delete)
self.logger.info(f"Deleted file entry {file_id} from relational DB.")
self.logger.info(f"Deleted file record for file_id: {file_id}")
else:
self.logger.warning(f"File entry {file_id} not found in relational DB.")
self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.")
session.commit()
self.logger.info(f"Data deletion completed for file_id: {file_id}.")
self.logger.info(f"Successfully completed data deletion for file_id: {file_id}")
except Exception as e:
session.rollback()
self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
raise
finally:
session.close()
async def delete_kb_by_id(self, kb_id: int):
"""
Deletes a knowledge base and all associated files and chunks.
Deletes a knowledge base and all associated files, chunks, and vectors.
This involves querying for associated files and then deleting them.
"""
self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
session = SessionLocal()
session = SessionLocal() # 使用新的会话来获取 KB 和关联文件
try:
# 1. Get the knowledge base
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb:
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb_to_delete:
self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
return
# 2. Delete all files associated with this knowledge base
files_to_delete = session.query(File).filter_by(kb_id=kb.id).all()
for file in files_to_delete:
await self.delete_data_by_file_id(file.id)
# 获取所有关联的文件,通过 File 表的 kb_id 字段查询
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
# 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话
session.close()
# 3. Delete the knowledge base itself
session.delete(kb)
# 遍历删除每个关联文件及其数据
for file_obj in files_to_delete:
try:
await self.delete_data_by_file_id(file_obj.id)
except Exception as file_del_e:
self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}")
# 记录错误但继续,尝试删除其他文件
# 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身
session = SessionLocal()
try:
# 重新查询,确保对象是当前会话的一部分
kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if kb_final_delete:
session.delete(kb_final_delete)
session.commit()
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
else:
self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.")
except Exception as kb_del_e:
session.rollback()
self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True)
raise
finally:
session.close()
except Exception as e:
# 如果在最初获取 KB 或文件列表时出错
if session.is_active:
session.rollback()
self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True)
raise
finally:
if session.is_active:
session.close()
async def get_file_content_by_file_id(self, file_id: str) -> str:
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
_, ext = os.path.splitext(file_id.lower())
ext = ext.lstrip('.')
try:
text = file_bytes.decode("utf-8")
except UnicodeDecodeError:
return "[非文本文件或编码无法识别]"
if ext in ["txt", "md", "csv", "log", "py", "html"]:
return text
else:
return f"[未知类型: .{ext}]"
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
"""
Associates a file with a knowledge base by updating the kb_id in the File table.
"""
self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
session = SessionLocal()
try:
# 查询知识库是否存在
kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
if not kb:
self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.")
return
# 更新文件的 kb_id
file_to_update = session.query(File).filter_by(id=file_id).first()
if not file_to_update:
self.logger.error(f"File with ID {file_id} not found.")
return
file_to_update.kb_id = kb.id
session.commit()
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}")
except Exception as e:
session.rollback()
self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True)
finally:
session.close()

View File

@@ -1,64 +1,23 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
# 全部迁移过去
Base = declarative_base()
from pkg.entity.persistence.rag import (
create_db_and_tables,
SessionLocal,
Base,
engine,
KnowledgeBase,
File,
Chunk,
Vector,
)
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model = Column(String, default='') # 默认嵌入模型
top_k = Column(Integer, default=5) # 默认返回的top_k数量
files = relationship('File', back_populates='knowledge_base')
class File(Base):
__tablename__ = 'file'
id = Column(Integer, primary_key=True, index=True)
kb_id = Column(Integer, ForeignKey('kb.id'))
file_name = Column(String)
path = Column(String)
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
knowledge_base = relationship('KnowledgeBase', back_populates='files')
chunks = relationship('Chunk', back_populates='file')
class Chunk(Base):
__tablename__ = 'chunks'
id = Column(Integer, primary_key=True, index=True)
file_id = Column(Integer, ForeignKey('file.id'))
text = Column(Text)
file = relationship('File', back_populates='chunks')
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship('Chunk', back_populates='vector')
# 数据库连接
DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建所有表 (可以在应用启动时执行一次)
def create_db_and_tables():
Base.metadata.create_all(bind=engine)
print('Database tables created/checked.')
# 定义嵌入维度(请根据你实际使用的模型调整)
EMBEDDING_DIM = 1024
__all__ = [
"create_db_and_tables",
"SessionLocal",
"Base",
"engine",
"KnowledgeBase",
"File",
"Chunk",
"Vector",
]