diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index 292f23ce..6ded737a 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -62,8 +62,8 @@ class RAG_Manager: chroma_manager=self.chroma_manager # Inject dependency ) - - async def create_knowledge_base(self, kb_name: str, kb_description: str ,): + + async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. @@ -82,7 +82,7 @@ class RAG_Manager: def _add_kb_sync(): session = SessionLocal() try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description) + new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k) session.add(new_kb) session.commit() session.refresh(new_kb) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index 4ec21af3..a8c35883 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -11,7 +11,8 @@ class KnowledgeBase(Base): name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - + embedding_model = Column(String, default="") # 默认嵌入模型 + top_k = Column(Integer, default=5) # 默认返回的top_k数量 files = relationship("File", back_populates="knowledge_base") class File(Base):