mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 16:26:02 +00:00
chore: stash
This commit is contained in:
@@ -1,149 +1,189 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
|
||||
from pkg.rag.knowledge.services.database import (
|
||||
KnowledgeBase,
|
||||
File,
|
||||
Chunk,
|
||||
)
|
||||
from pkg.core import app
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from pkg.core import taskmgr
|
||||
from ...entity.persistence import rag as persistence_rag
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class RuntimeKnowledgeBase:
|
||||
ap: app.Application
|
||||
|
||||
knowledge_base_entity: persistence_rag.KnowledgeBase
|
||||
|
||||
chroma_manager: ChromaIndexManager
|
||||
|
||||
parser: FileParser
|
||||
|
||||
chunker: Chunker
|
||||
|
||||
embedder: Embedder
|
||||
|
||||
retriever: Retriever
|
||||
|
||||
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
|
||||
self.ap = ap
|
||||
self.knowledge_base_entity = knowledge_base_entity
|
||||
self.chroma_manager = ChromaIndexManager(ap=self.ap)
|
||||
self.parser = FileParser(ap=self.ap)
|
||||
self.chunker = Chunker(ap=self.ap)
|
||||
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
|
||||
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
|
||||
try:
|
||||
# set file status to processing
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.File)
|
||||
.where(persistence_rag.File.uuid == file.uuid)
|
||||
.values(status='processing')
|
||||
)
|
||||
|
||||
task_context.set_current_action('Parsing file')
|
||||
# parse file
|
||||
text = await self.parser.parse(file.file_name, file.extension)
|
||||
if not text:
|
||||
raise Exception(f'No text extracted from file {file.file_name}')
|
||||
|
||||
task_context.set_current_action('Chunking file')
|
||||
# chunk file
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
if not chunks_texts:
|
||||
raise Exception(f'No chunks extracted from file {file.file_name}')
|
||||
|
||||
task_context.set_current_action('Embedding chunks')
|
||||
# embed chunks
|
||||
await self.embedder.embed_and_store(
|
||||
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid
|
||||
)
|
||||
|
||||
# set file status to completed
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.File)
|
||||
.where(persistence_rag.File.uuid == file.uuid)
|
||||
.values(status='completed')
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error storing file {file.file_id}: {e}')
|
||||
# set file status to failed
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.File)
|
||||
.where(persistence_rag.File.uuid == file.uuid)
|
||||
.values(status='failed')
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
async def store_file(self, file_id: str) -> int:
|
||||
# pre checking
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
||||
raise Exception(f'File {file_id} not found')
|
||||
|
||||
file_uuid = str(uuid.uuid4())
|
||||
kb_id = self.knowledge_base_entity.uuid
|
||||
file_name = file_id
|
||||
extension = os.path.splitext(file_id)[1].lstrip('.')
|
||||
|
||||
file = persistence_rag.File(
|
||||
uuid=file_uuid,
|
||||
kb_id=kb_id,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
status='pending',
|
||||
)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict()))
|
||||
|
||||
# run background task asynchronously
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._store_file_task(file, task_context=ctx),
|
||||
kind='knowledge-operation',
|
||||
name=f'knowledge-store-file-{file_id}',
|
||||
label=f'Store file {file_id}',
|
||||
context=ctx,
|
||||
)
|
||||
return wrapper.id
|
||||
|
||||
async def dispose(self):
|
||||
pass
|
||||
|
||||
|
||||
class RAGManager:
|
||||
ap: app.Application
|
||||
|
||||
knowledge_bases: list[RuntimeKnowledgeBase]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.chroma_manager = ChromaIndexManager()
|
||||
self.parser = FileParser()
|
||||
self.chunker = Chunker()
|
||||
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
|
||||
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)
|
||||
self.knowledge_bases = []
|
||||
|
||||
async def initialize_rag_system(self):
|
||||
"""Initializes the RAG system by creating database tables."""
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def create_knowledge_base(
|
||||
self, kb_name: str, kb_description: str, embedding_model_uuid: str = '', top_k: int = 5
|
||||
):
|
||||
"""
|
||||
Creates a new knowledge base if it doesn't already exist.
|
||||
"""
|
||||
try:
|
||||
if not kb_name:
|
||||
raise ValueError('Knowledge base name must be set while creating.')
|
||||
async def load_knowledge_bases_from_db(self):
|
||||
self.ap.logger.info('Loading knowledge bases from db...')
|
||||
|
||||
def _create_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
|
||||
if not kb:
|
||||
id = str(uuid.uuid4())
|
||||
new_kb = KnowledgeBase(
|
||||
name=kb_name,
|
||||
description=kb_description,
|
||||
embedding_model_uuid=embedding_model_uuid,
|
||||
top_k=top_k,
|
||||
id=id,
|
||||
)
|
||||
session.add(new_kb)
|
||||
session.commit()
|
||||
session.refresh(new_kb)
|
||||
self.ap.logger.info(f"Knowledge Base '{kb_name}' created.")
|
||||
print(embedding_model_uuid)
|
||||
return new_kb.id
|
||||
else:
|
||||
self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
self.knowledge_bases = []
|
||||
|
||||
return await asyncio.to_thread(_create_kb_sync)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
|
||||
|
||||
async def get_all_knowledge_bases(self):
|
||||
"""
|
||||
Retrieves all knowledge bases from the database.
|
||||
"""
|
||||
try:
|
||||
knowledge_bases = result.all()
|
||||
|
||||
def _get_all_kbs_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).all()
|
||||
finally:
|
||||
session.close()
|
||||
for knowledge_base in knowledge_bases:
|
||||
try:
|
||||
await self.load_knowledge_base(knowledge_base)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
|
||||
)
|
||||
|
||||
return await asyncio.to_thread(_get_all_kbs_sync)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True)
|
||||
return []
|
||||
async def load_knowledge_base(
|
||||
self,
|
||||
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
|
||||
) -> RuntimeKnowledgeBase:
|
||||
if isinstance(knowledge_base_entity, sqlalchemy.Row):
|
||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
|
||||
elif isinstance(knowledge_base_entity, dict):
|
||||
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)
|
||||
|
||||
async def get_knowledge_base_by_id(self, kb_id: str):
|
||||
"""
|
||||
Retrieves a specific knowledge base by its ID.
|
||||
"""
|
||||
try:
|
||||
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)
|
||||
|
||||
def _get_kb_sync(kb_id_param):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
|
||||
finally:
|
||||
session.close()
|
||||
await runtime_knowledge_base.initialize()
|
||||
|
||||
return await asyncio.to_thread(_get_kb_sync, kb_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True)
|
||||
return None
|
||||
self.knowledge_bases.append(runtime_knowledge_base)
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_id: str):
|
||||
"""
|
||||
Retrieves files associated with a specific knowledge base by querying the File table directly.
|
||||
"""
|
||||
try:
|
||||
return runtime_knowledge_base
|
||||
|
||||
def _get_files_sync(kb_id_param):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(File).filter_by(kb_id=kb_id_param).all()
|
||||
finally:
|
||||
session.close()
|
||||
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None:
|
||||
for kb in self.knowledge_bases:
|
||||
if kb.knowledge_base_entity.uuid == kb_uuid:
|
||||
return kb
|
||||
return None
|
||||
|
||||
return await asyncio.to_thread(_get_files_sync, kb_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_all_files(self):
|
||||
"""
|
||||
Retrieves all files stored in the database, regardless of their association
|
||||
with any specific knowledge base.
|
||||
"""
|
||||
try:
|
||||
|
||||
def _get_all_files_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(File).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return await asyncio.to_thread(_get_all_files_sync)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True)
|
||||
return []
|
||||
async def remove_knowledge_base(self, kb_uuid: str):
|
||||
for kb in self.knowledge_bases:
|
||||
if kb.knowledge_base_entity.uuid == kb_uuid:
|
||||
await kb.dispose()
|
||||
self.knowledge_bases.remove(kb)
|
||||
return
|
||||
|
||||
async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
|
||||
"""
|
||||
@@ -220,7 +260,8 @@ class RAGManager:
|
||||
await self.ap.storage_mgr.storage_provider.delete(file_id)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'Error deleting file from storage for file_id {file_id}: {str(e)}', exc_info=True
|
||||
f'Error deleting file from storage for file_id {file_id}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
self.ap.logger.info(f'Deleted file record for file_id: {file_id}')
|
||||
else:
|
||||
@@ -273,7 +314,10 @@ class RAGManager:
|
||||
)
|
||||
except Exception as kb_del_e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True)
|
||||
self.ap.logger.error(
|
||||
f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}',
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
@@ -283,7 +327,8 @@ class RAGManager:
|
||||
if session.is_active:
|
||||
session.rollback()
|
||||
self.ap.logger.error(
|
||||
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True
|
||||
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user