fix: bugs

This commit is contained in:
Junyan Qin
2025-07-11 16:38:08 +08:00
parent 367d04d0f0
commit 9ba1ad5bd3
3 changed files with 15 additions and 58 deletions

View File

@@ -14,17 +14,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
'uuid': kb.id,
'name': kb.name,
'description': kb.description,
'embedding_model_uuid': kb.embedding_model_uuid,
'top_k': kb.top_k,
}
for kb in knowledge_bases
]
return self.success(data={'bases': bases_list})
# POST: create a new knowledge base
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'), json_data.get('description')
)
return self.success(data={'uuid': knowledge_base_uuid})
elif quart.request.method == 'POST':
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid')
)
return self.success(data={'uuid': knowledge_base_uuid})
@self.route(
'/<knowledge_base_uuid>',

View File

@@ -5,13 +5,10 @@ import os
Base = declarative_base()
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db")
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
engine = create_engine(
DATABASE_URL,
connect_args={"check_same_thread": False}
)
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -20,7 +17,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def create_db_and_tables():
"""Creates all database tables defined in the Base."""
Base.metadata.create_all(bind=engine)
print("Database tables created or already exist.")
print('Database tables created or already exist.')
class KnowledgeBase(Base):
__tablename__ = 'kb'
@@ -28,7 +26,7 @@ class KnowledgeBase(Base):
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model = Column(String, default='')
embedding_model_uuid = Column(String, default='')
top_k = Column(Integer, default=5)

View File

@@ -6,11 +6,7 @@ import asyncio
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.core import app
@@ -20,8 +16,6 @@ class RAGManager:
def __init__(self, ap: app.Application, logger: logging.Logger = None):
self.ap = ap
self.logger = logger or logging.getLogger(__name__)
self.embedding_model_type = None
self.embedding_model_name = None
self.chroma_manager = None
self.parser = FileParser()
self.chunker = Chunker()
@@ -32,50 +26,13 @@ class RAGManager:
"""Initializes the RAG system by creating database tables."""
await asyncio.to_thread(create_db_and_tables)
async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str):
"""
Creates and configures the specific embedding model and ChromaDB manager.
This must be called before performing embedding or retrieval operations.
"""
self.embedding_model_type = embedding_model_type
self.embedding_model_name = embedding_model_name
try:
model = EmbeddingModelFactory.create_model(
model_type=self.embedding_model_type, model_name_key=self.embedding_model_name
)
self.logger.info(
f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}"
)
except Exception as e:
self.logger.critical(
f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}"
)
raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.')
self.chroma_manager = ChromaIndexManager(
collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}'
)
self.embedder = Embedder(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
self.retriever = Retriever(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5):
"""
Creates a new knowledge base if it doesn't already exist.
"""
try:
if not self.embedding_model_type or not kb_name:
raise ValueError(
'Embedding model type and knowledge base name must be set before creating a knowledge base.'
)
if not kb_name:
raise ValueError('Knowledge base name must be set while creating.')
def _create_kb_sync():
session = SessionLocal()