mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
fix: bugs
This commit is contained in:
@@ -14,17 +14,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
'uuid': kb.id,
|
||||
'name': kb.name,
|
||||
'description': kb.description,
|
||||
'embedding_model_uuid': kb.embedding_model_uuid,
|
||||
'top_k': kb.top_k,
|
||||
}
|
||||
for kb in knowledge_bases
|
||||
]
|
||||
return self.success(data={'bases': bases_list})
|
||||
|
||||
# POST: create a new knowledge base
|
||||
json_data = await quart.request.json
|
||||
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
|
||||
json_data.get('name'), json_data.get('description')
|
||||
)
|
||||
return self.success(data={'uuid': knowledge_base_uuid})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
|
||||
json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid')
|
||||
)
|
||||
return self.success(data={'uuid': knowledge_base_uuid})
|
||||
|
||||
@self.route(
|
||||
'/<knowledge_base_uuid>',
|
||||
|
||||
@@ -5,13 +5,10 @@ import os
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db")
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
@@ -20,7 +17,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
def create_db_and_tables():
|
||||
"""Creates all database tables defined in the Base."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print("Database tables created or already exist.")
|
||||
print('Database tables created or already exist.')
|
||||
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
@@ -28,7 +26,7 @@ class KnowledgeBase(Base):
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default='')
|
||||
embedding_model_uuid = Column(String, default='')
|
||||
top_k = Column(Integer, default=5)
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,7 @@ import asyncio
|
||||
import uuid
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
|
||||
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from pkg.core import app
|
||||
|
||||
|
||||
@@ -20,8 +16,6 @@ class RAGManager:
|
||||
def __init__(self, ap: app.Application, logger: logging.Logger = None):
|
||||
self.ap = ap
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.embedding_model_type = None
|
||||
self.embedding_model_name = None
|
||||
self.chroma_manager = None
|
||||
self.parser = FileParser()
|
||||
self.chunker = Chunker()
|
||||
@@ -32,50 +26,13 @@ class RAGManager:
|
||||
"""Initializes the RAG system by creating database tables."""
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
|
||||
async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str):
|
||||
"""
|
||||
Creates and configures the specific embedding model and ChromaDB manager.
|
||||
This must be called before performing embedding or retrieval operations.
|
||||
"""
|
||||
self.embedding_model_type = embedding_model_type
|
||||
self.embedding_model_name = embedding_model_name
|
||||
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(
|
||||
model_type=self.embedding_model_type, model_name_key=self.embedding_model_name
|
||||
)
|
||||
self.logger.info(
|
||||
f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.critical(
|
||||
f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}"
|
||||
)
|
||||
raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.')
|
||||
|
||||
self.chroma_manager = ChromaIndexManager(
|
||||
collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}'
|
||||
)
|
||||
self.embedder = Embedder(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager,
|
||||
)
|
||||
self.retriever = Retriever(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager,
|
||||
)
|
||||
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5):
|
||||
"""
|
||||
Creates a new knowledge base if it doesn't already exist.
|
||||
"""
|
||||
try:
|
||||
if not self.embedding_model_type or not kb_name:
|
||||
raise ValueError(
|
||||
'Embedding model type and knowledge base name must be set before creating a knowledge base.'
|
||||
)
|
||||
if not kb_name:
|
||||
raise ValueError('Knowledge base name must be set while creating.')
|
||||
|
||||
def _create_kb_sync():
|
||||
session = SessionLocal()
|
||||
|
||||
Reference in New Issue
Block a user