From 9ba1ad5bd38e48a3315f165ce88e73c3399b56fc Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 16:38:08 +0800 Subject: [PATCH] fix: bugs --- .../http/controller/groups/knowledge/base.py | 14 +++--- pkg/entity/persistence/rag.py | 12 ++--- pkg/rag/knowledge/mgr.py | 47 +------------------ 3 files changed, 15 insertions(+), 58 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index bfbbbe10..b5a48d29 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -14,17 +14,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): 'uuid': kb.id, 'name': kb.name, 'description': kb.description, + 'embedding_model_uuid': kb.embedding_model_uuid, + 'top_k': kb.top_k, } for kb in knowledge_bases ] return self.success(data={'bases': bases_list}) - # POST: create a new knowledge base - json_data = await quart.request.json - knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( - json_data.get('name'), json_data.get('description') - ) - return self.success(data={'uuid': knowledge_base_uuid}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( + json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid') + ) + return self.success(data={'uuid': knowledge_base_uuid}) @self.route( '/', diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 175720f1..1657196a 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -5,13 +5,10 @@ import os Base = declarative_base() -DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db") +DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') -engine = create_engine( - DATABASE_URL, - connect_args={"check_same_thread": False} -) +engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -20,7 +17,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def create_db_and_tables(): """Creates all database tables defined in the Base.""" Base.metadata.create_all(bind=engine) - print("Database tables created or already exist.") + print('Database tables created or already exist.') + class KnowledgeBase(Base): __tablename__ = 'kb' @@ -28,7 +26,7 @@ class KnowledgeBase(Base): name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default='') + embedding_model_uuid = Column(String, default='') top_k = Column(Integer, default=5) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 7d1787e0..5d4eece9 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -6,11 +6,7 @@ import asyncio import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker -from pkg.rag.knowledge.services.embedder import Embedder -from pkg.rag.knowledge.services.retriever import Retriever from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk -from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager from pkg.core import app @@ -20,8 +16,6 @@ class RAGManager: def __init__(self, ap: app.Application, logger: logging.Logger = None): self.ap = ap self.logger = logger or logging.getLogger(__name__) - self.embedding_model_type = None - self.embedding_model_name = None self.chroma_manager = None self.parser = FileParser() self.chunker = Chunker() @@ -32,50 +26,13 @@ class RAGManager: """Initializes the RAG system by creating database tables.""" await asyncio.to_thread(create_db_and_tables) - async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): - """ - Creates and configures the specific embedding model and ChromaDB manager. - This must be called before performing embedding or retrieval operations. - """ - self.embedding_model_type = embedding_model_type - self.embedding_model_name = embedding_model_name - - try: - model = EmbeddingModelFactory.create_model( - model_type=self.embedding_model_type, model_name_key=self.embedding_model_name - ) - self.logger.info( - f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}" - ) - except Exception as e: - self.logger.critical( - f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}" - ) - raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.') - - self.chroma_manager = ChromaIndexManager( - collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}' - ) - self.embedder = Embedder( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager, - ) - self.retriever = Retriever( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager, - ) - async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5): """ Creates a new knowledge base if it doesn't already exist. """ try: - if not self.embedding_model_type or not kb_name: - raise ValueError( - 'Embedding model type and knowledge base name must be set before creating a knowledge base.' - ) + if not kb_name: + raise ValueError('Knowledge base name must be set while creating.') def _create_kb_sync(): session = SessionLocal()