mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-09 07:16:04 +00:00
feat: modify the rag.py
This commit is contained in:
@@ -1,64 +1,23 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
# 全部迁移过去
|
||||
|
||||
Base = declarative_base()
|
||||
from pkg.entity.persistence.rag import (
|
||||
create_db_and_tables,
|
||||
SessionLocal,
|
||||
Base,
|
||||
engine,
|
||||
KnowledgeBase,
|
||||
File,
|
||||
Chunk,
|
||||
Vector,
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default='') # 默认嵌入模型
|
||||
top_k = Column(Integer, default=5) # 默认返回的top_k数量
|
||||
files = relationship('File', back_populates='knowledge_base')
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
kb_id = Column(Integer, ForeignKey('kb.id'))
|
||||
file_name = Column(String)
|
||||
path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
|
||||
knowledge_base = relationship('KnowledgeBase', back_populates='files')
|
||||
chunks = relationship('Chunk', back_populates='file')
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
file_id = Column(Integer, ForeignKey('file.id'))
|
||||
text = Column(Text)
|
||||
|
||||
file = relationship('File', back_populates='chunks')
|
||||
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
|
||||
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship('Chunk', back_populates='vector')
|
||||
|
||||
|
||||
# 数据库连接
|
||||
DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL
|
||||
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
# 创建所有表 (可以在应用启动时执行一次)
|
||||
def create_db_and_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print('Database tables created/checked.')
|
||||
|
||||
|
||||
# 定义嵌入维度(请根据你实际使用的模型调整)
|
||||
EMBEDDING_DIM = 1024
|
||||
__all__ = [
|
||||
"create_db_and_tables",
|
||||
"SessionLocal",
|
||||
"Base",
|
||||
"engine",
|
||||
"KnowledgeBase",
|
||||
"File",
|
||||
"Chunk",
|
||||
"Vector",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user