diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index ce366539..3f34d79b 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -14,7 +14,7 @@ preregistered_groups: list[type[RouterGroup]] = [] """RouterGroup 的预注册列表""" -def group_class(name: str, path: str) -> None: +def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]: """注册一个 RouterGroup""" def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]: @@ -120,6 +120,6 @@ class RouterGroup(abc.ABC): } ) - def http_status(self, status: int, code: int, msg: str) -> quart.Response: + def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]: """返回一个指定状态码的响应""" - return self.fail(code, msg), status + return (self.fail(code, msg), status) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index c2208f6f..866b4af2 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -5,77 +5,49 @@ from ... import group @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): async def initialize(self) -> None: - @self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases') - async def handle_knowledge_bases() -> str: + @self.route('', methods=['POST', 'GET']) + async def handle_knowledge_bases() -> quart.Response: if quart.request.method == 'GET': - knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() - bases_list = [ - { - 'uuid': kb.id, - 'name': kb.name, - 'description': kb.description, - 'embedding_model_uuid': kb.embedding_model_uuid, - 'top_k': kb.top_k, - } - for kb in knowledge_bases - ] - return self.success(data={'bases': bases_list}) + knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases() + return self.success(data={'bases': knowledge_bases}) elif quart.request.method == 'POST': json_data = await quart.request.json - knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( - json_data.get('name'), - json_data.get('description'), - json_data.get('embedding_model_uuid'), - json_data.get('top_k',5), - ) + knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data) return self.success(data={'uuid': knowledge_base_uuid}) + return self.http_status(405, -1, 'Method not allowed') + @self.route( '/', methods=['GET', 'DELETE'], - endpoint='handle_specific_knowledge_base', ) - async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str: + async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response: if quart.request.method == 'GET': - knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') return self.success( data={ - 'base': { - 'name': knowledge_base.name, - 'description': knowledge_base.description, - 'uuid': knowledge_base.id, - 'embedding_model_uuid': knowledge_base.embedding_model_uuid, - 'top_k': knowledge_base.top_k, - }, + 'base': knowledge_base, } ) elif quart.request.method == 'DELETE': - await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid) return self.success({}) @self.route( '//files', methods=['GET', 'POST'], - endpoint='get_knowledge_base_files', ) async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) + files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid) return self.success( data={ - 'files': [ - { - 'id': file.id, - 'file_name': file.file_name, - 'status': file.status, - } - for file in files - ], + 'files': files, } ) @@ -86,14 +58,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return self.http_status(400, -1, 'File ID is required') # 调用服务层方法将文件与知识库关联 - await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id) - return self.success({}) + task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id) + return self.success( + { + 'task_id': task_id, + } + ) @self.route( '//files/', methods=['DELETE'], - endpoint='delete_specific_file_in_kb', ) - async def delete_specific_file_in_kb(file_id: str,knowledge_base_uuid: str) -> str: - await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) + async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str: + await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id) return self.success({}) diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py new file mode 100644 index 00000000..5d702ba4 --- /dev/null +++ b/pkg/api/http/service/knowledge.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import uuid +import sqlalchemy + +from ....core import app +from ....entity.persistence import rag as persistence_rag + + +class KnowledgeService: + """知识库服务""" + + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_knowledge_bases(self) -> list[dict]: + """获取所有知识库""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) + knowledge_bases = result.all() + return [ + self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base) + for knowledge_base in knowledge_bases + ] + + async def get_knowledge_base(self, kb_uuid: str) -> dict | None: + """获取知识库""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + knowledge_base = result.first() + if knowledge_base is None: + return None + return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base) + + async def create_knowledge_base(self, kb_data: dict) -> str: + """创建知识库""" + kb_data['uuid'] = str(uuid.uuid4()) + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data)) + + kb = await self.get_knowledge_base(kb_data['uuid']) + + await self.ap.rag_mgr.load_knowledge_base(kb) + + return kb_data['uuid'] + + async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None: + """更新知识库""" + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.KnowledgeBase) + .values(kb_data) + .where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + await self.ap.rag_mgr.remove_knowledge_base(kb_uuid) + + kb = await self.get_knowledge_base(kb_uuid) + + await self.ap.rag_mgr.load_knowledge_base(kb) + + async def store_file(self, kb_uuid: str, file_id: str) -> int: + """存储文件""" + # await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id)) + # await self.ap.rag_mgr.store_file(file_id) + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + return await runtime_kb.store_file(file_id) + + async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: + """获取知识库文件""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid) + ) + files = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files] + + async def delete_file(self, kb_uuid: str, file_id: str) -> None: + """删除文件""" + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) + ) + # TODO: remove from memory + + async def delete_knowledge_base(self, kb_uuid: str) -> None: + """删除知识库""" + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + # TODO: remove from memory diff --git a/pkg/core/app.py b/pkg/core/app.py index 11d25826..ca2c5c1c 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -22,6 +22,7 @@ from ..api.http.service import user as user_service from ..api.http.service import model as model_service from ..api.http.service import pipeline as pipeline_service from ..api.http.service import bot as bot_service +from ..api.http.service import knowledge as knowledge_service from ..discover import engine as discover_engine from ..storage import mgr as storagemgr from ..utils import logcache @@ -48,6 +49,8 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + rag_mgr: rag_mgr.RAGManager = None + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -112,7 +115,7 @@ class Application: bot_service: bot_service.BotService = None - knowledge_base_service: rag_mgr.RAGManager = None + knowledge_service: knowledge_service.KnowledgeService = None def __init__(self): pass diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index ac76c331..18240962 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -17,6 +17,7 @@ from ...api.http.service import user as user_service from ...api.http.service import model as model_service from ...api.http.service import pipeline as pipeline_service from ...api.http.service import bot as bot_service +from ...api.http.service import knowledge as knowledge_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache @@ -89,6 +90,10 @@ class BuildAppStage(stage.BootingStage): await pipeline_mgr.initialize() ap.pipeline_mgr = pipeline_mgr + rag_mgr_inst = rag_mgr.RAGManager(ap) + await rag_mgr_inst.initialize_rag_system() + ap.rag_mgr = rag_mgr_inst + http_ctrl = http_controller.HTTPController(ap) await http_ctrl.initialize() ap.http_ctrl = http_ctrl @@ -102,15 +107,14 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst - knowledge_base_service_inst = rag_mgr.RAGManager(ap) - await knowledge_base_service_inst.initialize_rag_system() - ap.knowledge_base_service = knowledge_base_service_inst - pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst bot_service_inst = bot_service.BotService(ap) ap.bot_service = bot_service_inst + knowledge_service_inst = knowledge_service.KnowledgeService(ap) + ap.knowledge_service = knowledge_service_inst + ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 9ca84741..0ff93d28 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -1,51 +1,50 @@ -from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer -from sqlalchemy.orm import declarative_base, sessionmaker -from datetime import datetime -import os -import uuid +import sqlalchemy +from .base import Base -Base = declarative_base() -DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') -print("Using database URL:", DATABASE_URL) +# Base = declarative_base() +# DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') +# print("Using database URL:", DATABASE_URL) -engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) +# engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -def create_db_and_tables(): - """Creates all database tables defined in the Base.""" - Base.metadata.create_all(bind=engine) - print('Database tables created or already exist.') +# def create_db_and_tables(): +# """Creates all database tables defined in the Base.""" +# Base.metadata.create_all(bind=engine) +# print('Database tables created or already exist.') class KnowledgeBase(Base): - __tablename__ = 'kb' - id = Column(String, primary_key=True, index=True) - name = Column(String, index=True) - description = Column(Text) - created_at = Column(DateTime, default=datetime.utcnow) - embedding_model_uuid = Column(String, default='') - top_k = Column(Integer, default=5) + __tablename__ = 'knowledge_bases' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String, index=True) + description = sqlalchemy.Column(sqlalchemy.Text) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now()) + embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='') + top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5) + class File(Base): - __tablename__ = 'file' - id = Column(String, primary_key=True, index=True) - kb_id = Column(String, nullable=True) - file_name = Column(String) - path = Column(String) - created_at = Column(DateTime, default=datetime.utcnow) - file_type = Column(String) - status = Column(Integer, default=0) # 0: uploaded and processing, 1: completed, 2: failed + __tablename__ = 'knowledge_base_files' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + kb_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + file_name = sqlalchemy.Column(sqlalchemy.String) + extension = sqlalchemy.Column(sqlalchemy.String) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now()) + status = sqlalchemy.Column(sqlalchemy.String, default='pending') # pending, processing, completed, failed + class Chunk(Base): - __tablename__ = 'chunks' - id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - file_id = Column(String, nullable=True) - text = Column(Text) + __tablename__ = 'knowledge_base_chunks' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + text = sqlalchemy.Column(sqlalchemy.Text) -class Vector(Base): - __tablename__ = 'vectors' - id = Column(String, primary_key=True, index=True) - chunk_id = Column(String, nullable=True) - embedding = Column(LargeBinary) + +# class Vector(Base): +# __tablename__ = 'knowledge_base_vectors' +# uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) +# chunk_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) +# embedding = sqlalchemy.Column(sqlalchemy.LargeBinary) diff --git a/pkg/platform/logger.py b/pkg/platform/logger.py index 340baa07..a2ea2e25 100644 --- a/pkg/platform/logger.py +++ b/pkg/platform/logger.py @@ -119,7 +119,7 @@ class EventLogger: async def _truncate_logs(self): if len(self.logs) > MAX_LOG_COUNT: for i in range(DELETE_COUNT_PER_TIME): - for image_key in self.logs[i].images: + for image_key in self.logs[i].images: # type: ignore await self.ap.storage_mgr.storage_provider.delete(image_key) self.logs = self.logs[DELETE_COUNT_PER_TIME:] diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 58de52f5..28b8d666 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -1,149 +1,189 @@ - from __future__ import annotations import os import asyncio +import traceback import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker -from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk +from pkg.rag.knowledge.services.database import ( + KnowledgeBase, + File, + Chunk, +) from pkg.core import app from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from ...entity.persistence import model as persistence_model +from pkg.core import taskmgr +from ...entity.persistence import rag as persistence_rag import sqlalchemy +class RuntimeKnowledgeBase: + ap: app.Application + + knowledge_base_entity: persistence_rag.KnowledgeBase + + chroma_manager: ChromaIndexManager + + parser: FileParser + + chunker: Chunker + + embedder: Embedder + + retriever: Retriever + + def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): + self.ap = ap + self.knowledge_base_entity = knowledge_base_entity + self.chroma_manager = ChromaIndexManager(ap=self.ap) + self.parser = FileParser(ap=self.ap) + self.chunker = Chunker(ap=self.ap) + self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager) + self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager) + + async def initialize(self): + pass + + async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext): + try: + # set file status to processing + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='processing') + ) + + task_context.set_current_action('Parsing file') + # parse file + text = await self.parser.parse(file.file_name, file.extension) + if not text: + raise Exception(f'No text extracted from file {file.file_name}') + + task_context.set_current_action('Chunking file') + # chunk file + chunks_texts = await self.chunker.chunk(text) + if not chunks_texts: + raise Exception(f'No chunks extracted from file {file.file_name}') + + task_context.set_current_action('Embedding chunks') + # embed chunks + await self.embedder.embed_and_store( + file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid + ) + + # set file status to completed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='completed') + ) + + except Exception as e: + self.ap.logger.error(f'Error storing file {file.file_id}: {e}') + # set file status to failed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='failed') + ) + + raise + + async def store_file(self, file_id: str) -> int: + # pre checking + if not await self.ap.storage_mgr.storage_provider.exists(file_id): + raise Exception(f'File {file_id} not found') + + file_uuid = str(uuid.uuid4()) + kb_id = self.knowledge_base_entity.uuid + file_name = file_id + extension = os.path.splitext(file_id)[1].lstrip('.') + + file = persistence_rag.File( + uuid=file_uuid, + kb_id=kb_id, + file_name=file_name, + extension=extension, + status='pending', + ) + + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict())) + + # run background task asynchronously + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._store_file_task(file, task_context=ctx), + kind='knowledge-operation', + name=f'knowledge-store-file-{file_id}', + label=f'Store file {file_id}', + context=ctx, + ) + return wrapper.id + + async def dispose(self): + pass + + class RAGManager: ap: app.Application + knowledge_bases: list[RuntimeKnowledgeBase] + def __init__(self, ap: app.Application): self.ap = ap - self.chroma_manager = ChromaIndexManager() - self.parser = FileParser() - self.chunker = Chunker() - self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager) - self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager) + self.knowledge_bases = [] - async def initialize_rag_system(self): - """Initializes the RAG system by creating database tables.""" - await asyncio.to_thread(create_db_and_tables) + async def initialize(self): + pass - async def create_knowledge_base( - self, kb_name: str, kb_description: str, embedding_model_uuid: str = '', top_k: int = 5 - ): - """ - Creates a new knowledge base if it doesn't already exist. - """ - try: - if not kb_name: - raise ValueError('Knowledge base name must be set while creating.') + async def load_knowledge_bases_from_db(self): + self.ap.logger.info('Loading knowledge bases from db...') - def _create_kb_sync(): - session = SessionLocal() - try: - kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() - if not kb: - id = str(uuid.uuid4()) - new_kb = KnowledgeBase( - name=kb_name, - description=kb_description, - embedding_model_uuid=embedding_model_uuid, - top_k=top_k, - id=id, - ) - session.add(new_kb) - session.commit() - session.refresh(new_kb) - self.ap.logger.info(f"Knowledge Base '{kb_name}' created.") - print(embedding_model_uuid) - return new_kb.id - else: - self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.") - except Exception as e: - session.rollback() - self.ap.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) - raise - finally: - session.close() + self.knowledge_bases = [] - return await asyncio.to_thread(_create_kb_sync) - except Exception as e: - self.ap.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) - raise + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) - async def get_all_knowledge_bases(self): - """ - Retrieves all knowledge bases from the database. - """ - try: + knowledge_bases = result.all() - def _get_all_kbs_sync(): - session = SessionLocal() - try: - return session.query(KnowledgeBase).all() - finally: - session.close() + for knowledge_base in knowledge_bases: + try: + await self.load_knowledge_base(knowledge_base) + except Exception as e: + self.ap.logger.error( + f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}' + ) - return await asyncio.to_thread(_get_all_kbs_sync) - except Exception as e: - self.ap.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) - return [] + async def load_knowledge_base( + self, + knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict, + ) -> RuntimeKnowledgeBase: + if isinstance(knowledge_base_entity, sqlalchemy.Row): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping) + elif isinstance(knowledge_base_entity, dict): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity) - async def get_knowledge_base_by_id(self, kb_id: str): - """ - Retrieves a specific knowledge base by its ID. - """ - try: + runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity) - def _get_kb_sync(kb_id_param): - session = SessionLocal() - try: - return session.query(KnowledgeBase).filter_by(id=kb_id_param).first() - finally: - session.close() + await runtime_knowledge_base.initialize() - return await asyncio.to_thread(_get_kb_sync, kb_id) - except Exception as e: - self.ap.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) - return None + self.knowledge_bases.append(runtime_knowledge_base) - async def get_files_by_knowledge_base(self, kb_id: str): - """ - Retrieves files associated with a specific knowledge base by querying the File table directly. - """ - try: + return runtime_knowledge_base - def _get_files_sync(kb_id_param): - session = SessionLocal() - try: - return session.query(File).filter_by(kb_id=kb_id_param).all() - finally: - session.close() + async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None: + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + return kb + return None - return await asyncio.to_thread(_get_files_sync, kb_id) - except Exception as e: - self.ap.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True) - return [] - - async def get_all_files(self): - """ - Retrieves all files stored in the database, regardless of their association - with any specific knowledge base. - """ - try: - - def _get_all_files_sync(): - session = SessionLocal() - try: - return session.query(File).all() - finally: - session.close() - - return await asyncio.to_thread(_get_all_files_sync) - except Exception as e: - self.ap.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True) - return [] + async def remove_knowledge_base(self, kb_uuid: str): + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + await kb.dispose() + self.knowledge_bases.remove(kb) + return async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None): """ @@ -220,7 +260,8 @@ class RAGManager: await self.ap.storage_mgr.storage_provider.delete(file_id) except Exception as e: self.ap.logger.error( - f'Error deleting file from storage for file_id {file_id}: {str(e)}', exc_info=True + f'Error deleting file from storage for file_id {file_id}: {str(e)}', + exc_info=True, ) self.ap.logger.info(f'Deleted file record for file_id: {file_id}') else: @@ -273,7 +314,10 @@ class RAGManager: ) except Exception as kb_del_e: session.rollback() - self.ap.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True) + self.ap.logger.error( + f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', + exc_info=True, + ) raise finally: session.close() @@ -283,7 +327,8 @@ class RAGManager: if session.is_active: session.rollback() self.ap.logger.error( - f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True + f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', + exc_info=True, ) raise finally: diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py index f8020cdb..17757b47 100644 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -1,43 +1,43 @@ - import numpy as np import logging from chromadb import PersistentClient -import os +from pkg.core import app logger = logging.getLogger(__name__) + class ChromaIndexManager: - def __init__(self, collection_name: str = "default_collection"): - self.logger = logging.getLogger(self.__class__.__name__) - chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma")) - os.makedirs(chroma_data_path, exist_ok=True) + def __init__(self, ap: app.Application, collection_name: str = 'default_collection'): + self.ap = ap + chroma_data_path = './data/chroma' self.client = PersistentClient(path=chroma_data_path) self._collection_name = collection_name self._collection = None - self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}") + self.ap.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}') @property def collection(self): if self._collection is None: self._collection = self.client.get_or_create_collection(name=self._collection_name) - self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") + self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") return self._collection - def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]): - if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents): - raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.") + def add_embeddings_sync( + self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str] + ): + if ( + embeddings.shape[0] != len(chunk_ids) + or embeddings.shape[0] != len(file_ids) + or embeddings.shape[0] != len(documents) + ): + raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.') - chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)] - metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)] + chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)] + metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)] self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") - self.collection.add( - embeddings=embeddings.tolist(), - ids=chroma_ids, - metadatas=metadatas, - documents=documents - ) + self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents) self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") def search_sync(self, query_embedding: np.ndarray, k: int = 5): @@ -54,12 +54,14 @@ class ChromaIndexManager: query_embeddings=query_embedding.tolist(), n_results=k, # REMOVE 'ids' from the include list. It's returned by default. - include=["metadatas", "distances", "documents"] + include=['metadatas', 'distances', 'documents'], ) - self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.") + self.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.') return results def delete_by_file_id_sync(self, file_id: int): - self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.") - self.collection.delete(where={"file_id": file_id}) - self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.") \ No newline at end of file + self.logger.info( + f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'." + ) + self.collection.delete(where={'file_id': file_id}) + self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.') diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index 2db7c104..93b10a55 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,21 +1,26 @@ # services/chunker.py import logging from typing import List -from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.core import app logger = logging.getLogger(__name__) + class Chunker(BaseService): """ A class for splitting long texts into smaller, overlapping chunks. """ - def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): - super().__init__() # Initialize BaseService - self.logger = logging.getLogger(self.__class__.__name__) + + def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50): + super().__init__(ap) # Initialize BaseService + self.ap = ap self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap if self.chunk_overlap >= self.chunk_size: - self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.") + self.logger.warning( + 'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.' + ) def _split_text_sync(self, text: str) -> List[str]: """ @@ -27,18 +32,19 @@ class Chunker(BaseService): # words = text.split() # chunks = [] # current_chunk = [] - + # for word in words: # current_chunk.append(word) # if len(current_chunk) > self.chunk_size: # chunks.append(" ".join(current_chunk[:self.chunk_size])) # current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] - + # if current_chunk: # chunks.append(" ".join(current_chunk)) - + # A more robust chunking strategy (e.g., using recursive character text splitter) from langchain.text_splitter import RecursiveCharacterTextSplitter + text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -51,8 +57,8 @@ class Chunker(BaseService): """ Asynchronously chunks a given text into smaller pieces. """ - self.logger.info(f"Chunking text (length: {len(text)})...") + self.ap.logger.info(f'Chunking text (length: {len(text)})...') # Run the synchronous splitting logic in a separate thread chunks = await self._run_sync(self._split_text_sync, text) - self.logger.info(f"Text chunked into {len(chunks)} pieces.") - return chunks \ No newline at end of file + self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.') + return chunks diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 4da7e82a..34165eab 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -7,16 +7,10 @@ from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService from pkg.rag.knowledge.services.database import Chunk, SessionLocal from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from sqlalchemy.orm import declarative_base, sessionmaker from ....core import app -from ....entity.persistence import model as persistence_model -import sqlalchemy from ....provider.modelmgr.requester import RuntimeEmbeddingModel -base = declarative_base() -logger = logging.getLogger(__name__) - class Embedder(BaseService): def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None: super().__init__() @@ -30,61 +24,66 @@ class Embedder(BaseService): This function assumes it's called within a context where the session will be committed/rolled back and closed by the caller. """ - self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).") + self.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).') chunk_objects = [] for text in chunks_texts: chunk = Chunk(file_id=file_id, text=text) session.add(chunk) chunk_objects.append(chunk) - session.flush() # This populates the .id attribute for each new chunk object - self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.") + session.flush() # This populates the .id attribute for each new chunk object + self.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.') return chunk_objects - async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]: + async def embed_and_store( + self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel + ) -> List[Chunk]: if not embedding_model: - raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.") + raise RuntimeError('Embedding model not loaded. Please check Embedder initialization.') - session = SessionLocal() # Start a session that will live for the whole operation + session = SessionLocal() # Start a session that will live for the whole operation chunk_objects = [] try: # 1. Save chunks to the relational database first to get their IDs # We call _db_save_chunks_sync directly without _run_sync's session management # because we manage the session here across multiple async calls. chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) - session.commit() # Commit chunks to make their IDs permanent and accessible + session.commit() # Commit chunks to make their IDs permanent and accessible if not chunk_objects: - self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.") + self.logger.warning( + f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.' + ) return [] - + # get the embeddings for the chunks - embeddings = [] - i = 0 - while i Any: """ @@ -35,138 +36,160 @@ class FileParser: try: return await asyncio.to_thread(sync_func, *args, **kwargs) except Exception as e: - self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}') + self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}') raise - async def parse(self, file_path: str) -> Union[str, None]: + async def parse(self, file_name: str, extension: str) -> Union[str, None]: """ Parses the file based on its extension and returns the extracted text content. This is the main asynchronous entry point for parsing. Args: - file_path (str): The path to the file to be parsed. + file_name (str): The name of the file to be parsed, get from ap.storage_mgr Returns: Union[str, None]: The extracted text content as a single string, or None if parsing fails. """ - if not file_path or not os.path.exists(file_path): - self.logger.error(f'Invalid file path provided: {file_path}') - return None - file_extension = file_path.split('.')[-1].lower() + file_extension = extension.lower() parser_method = getattr(self, f'_parse_{file_extension}', None) if parser_method is None: - self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}') + self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}') return None try: # Pass file_path to the specific parser methods - return await parser_method(file_path) + return await parser_method(file_name) except Exception as e: - self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}') + self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}') return None # --- Helper for reading files with encoding detection --- - async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]: + async def _read_file_content(self, file_name: str) -> Union[str, bytes]: """ Reads a file with automatic encoding detection, ensuring the synchronous file read operation runs in a separate thread. """ - def _read_sync(): - with open(file_path, 'rb') as file: - raw_data = file.read() - detected = chardet.detect(raw_data) - encoding = detected['encoding'] or 'utf-8' + # def _read_sync(): + # with open(file_path, 'rb') as file: + # raw_data = file.read() + # detected = chardet.detect(raw_data) + # encoding = detected['encoding'] or 'utf-8' - if mode == 'r': - return raw_data.decode(encoding, errors='ignore') - return raw_data # For binary mode + # if mode == 'r': + # return raw_data.decode(encoding, errors='ignore') + # return raw_data # For binary mode - return await self._run_sync(_read_sync) + # return await self._run_sync(_read_sync) + file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + detected = chardet.detect(file_bytes) + encoding = detected['encoding'] or 'utf-8' + + return file_bytes.decode(encoding, errors='ignore') # --- Specific Parser Methods --- - async def _parse_txt(self, file_path: str) -> str: + async def _parse_txt(self, file_name: str) -> str: """Parses a TXT file and returns its content.""" - self.logger.info(f'Parsing TXT file: {file_path}') - return await self._read_file_content(file_path, mode='r') + self.ap.logger.info(f'Parsing TXT file: {file_name}') + return await self._read_file_content(file_name) - async def _parse_pdf(self, file_path: str) -> str: + async def _parse_pdf(self, file_name: str) -> str: """Parses a PDF file and returns its text content.""" - self.logger.info(f'Parsing PDF file: {file_path}') + self.ap.logger.info(f'Parsing PDF file: {file_name}') + + # def _parse_pdf_sync(): + # text_content = [] + # with open(file_name, 'rb') as file: + # pdf_reader = PyPDF2.PdfReader(file) + # for page in pdf_reader.pages: + # text = page.extract_text() + # if text: + # text_content.append(text) + # return '\n'.join(text_content) + + # return await self._run_sync(_parse_pdf_sync) + + pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_pdf_sync(): + pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes)) text_content = [] - with open(file_path, 'rb') as file: - pdf_reader = PyPDF2.PdfReader(file) - for page in pdf_reader.pages: - text = page.extract_text() - if text: - text_content.append(text) + for page in pdf_reader.pages: + text = page.extract_text() + if text: + text_content.append(text) return '\n'.join(text_content) return await self._run_sync(_parse_pdf_sync) - async def _parse_docx(self, file_path: str) -> str: + async def _parse_docx(self, file_name: str) -> str: """Parses a DOCX file and returns its text content.""" - self.logger.info(f'Parsing DOCX file: {file_path}') + self.ap.logger.info(f'Parsing DOCX file: {file_name}') + + docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_docx_sync(): - doc = Document(file_path) + doc = Document(io.BytesIO(docx_bytes)) text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] return '\n'.join(text_content) return await self._run_sync(_parse_docx_sync) - async def _parse_doc(self, file_path: str) -> str: + async def _parse_doc(self, file_name: str) -> str: """Handles .doc files, explicitly stating lack of direct support.""" - self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.') + self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.') raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.') - async def _parse_xlsx(self, file_path: str) -> str: + async def _parse_xlsx(self, file_name: str) -> str: """Parses an XLSX file, returning text from all sheets.""" - self.logger.info(f'Parsing XLSX file: {file_path}') + self.ap.logger.info(f'Parsing XLSX file: {file_name}') + + xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_xlsx_sync(): - excel_file = pd.ExcelFile(file_path) + excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes)) all_sheet_content = [] for sheet_name in excel_file.sheet_names: - df = pd.read_excel(file_path, sheet_name=sheet_name) + df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name) sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' all_sheet_content.append(sheet_text) return '\n'.join(all_sheet_content) return await self._run_sync(_parse_xlsx_sync) - async def _parse_csv(self, file_path: str) -> str: + async def _parse_csv(self, file_name: str) -> str: """Parses a CSV file and returns its content as a string.""" - self.logger.info(f'Parsing CSV file: {file_path}') + self.ap.logger.info(f'Parsing CSV file: {file_name}') + + csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_csv_sync(): # pd.read_csv can often detect encoding, but explicit detection is safer - raw_data = self._read_file_content( - file_path, mode='rb' - ) # Note: this will need to be await outside this sync function - _ = raw_data + # raw_data = self._read_file_content( + # file_name, mode='rb' + # ) # Note: this will need to be await outside this sync function + # _ = raw_data # For simplicity, we'll let pandas handle encoding internally after a raw read. # A more robust solution might pass encoding directly to pd.read_csv after detection. - detected = chardet.detect(open(file_path, 'rb').read()) + detected = chardet.detect(io.BytesIO(csv_bytes)) encoding = detected['encoding'] or 'utf-8' - df = pd.read_csv(file_path, encoding=encoding) + df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding) return df.to_string(index=False) return await self._run_sync(_parse_csv_sync) - async def _parse_markdown(self, file_path: str) -> str: + async def _parse_markdown(self, file_name: str) -> str: """Parses a Markdown file, converting it to structured plain text.""" - self.logger.info(f'Parsing Markdown file: {file_path}') + self.ap.logger.info(f'Parsing Markdown file: {file_name}') + + md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_markdown_sync(): - md_content = self._read_file_content( - file_path, mode='r' - ) # This is a synchronous call within a sync function + md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore') html_content = markdown.markdown( md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] ) @@ -200,12 +223,14 @@ class FileParser: return await self._run_sync(_parse_markdown_sync) - async def _parse_html(self, file_path: str) -> str: + async def _parse_html(self, file_name: str) -> str: """Parses an HTML file, extracting structured plain text.""" - self.logger.info(f'Parsing HTML file: {file_path}') + self.ap.logger.info(f'Parsing HTML file: {file_name}') + + html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_html_sync(): - html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function + html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore') soup = BeautifulSoup(html_content, 'html.parser') for script_or_style in soup(['script', 'style']): script_or_style.decompose() @@ -236,12 +261,14 @@ class FileParser: return await self._run_sync(_parse_html_sync) - async def _parse_epub(self, file_path: str) -> str: + async def _parse_epub(self, file_name: str) -> str: """Parses an EPUB file, extracting metadata and content.""" - self.logger.info(f'Parsing EPUB file: {file_path}') + self.ap.logger.info(f'Parsing EPUB file: {file_name}') + + epub_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) def _parse_epub_sync(): - book = epub.read_epub(file_path) + book = epub.read_epub(io.BytesIO(epub_bytes)) text_content = [] title_meta = book.get_metadata('DC', 'title') if title_meta: diff --git a/pkg/vector/__init__.py b/pkg/vector/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/vector/mgr.py b/pkg/vector/mgr.py new file mode 100644 index 00000000..b2f47d61 --- /dev/null +++ b/pkg/vector/mgr.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ..core import app + + +class VectorDBManager: + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py new file mode 100644 index 00000000..100ded93 --- /dev/null +++ b/pkg/vector/vdb.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +import abc + + +class VectorDatabase(abc.ABC): + pass