diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index 50183f0f..c2208f6f 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -27,6 +27,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid'), + json_data.get('top_k',5), ) return self.success(data={'uuid': knowledge_base_uuid}) diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 9742a52c..0abebfa5 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -100,7 +100,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): async def invoke_embedding( self, - query: core_entities.Query, model: RuntimeEmbeddingModel, input_text: str, extra_args: dict[str, typing.Any] = {}, diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 98d1f13a..5dadab7d 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -144,7 +144,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def invoke_embedding( self, - query: core_entities.Query, model: requester.RuntimeEmbeddingModel, input_text: str, extra_args: dict[str, typing.Any] = {}, diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 6ebc85a7..be90f6f3 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -1,4 +1,4 @@ -# rag_manager.py + from __future__ import annotations import os import asyncio @@ -10,6 +10,8 @@ from pkg.core import app from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from ...entity.persistence import model as persistence_model +import sqlalchemy class RAGManager: @@ -20,9 +22,8 @@ class RAGManager: self.chroma_manager = ChromaIndexManager() self.parser = FileParser() self.chunker = Chunker() - # Initialize Embedder with targeted model type and name - self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager) - self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager) + self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager) + self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager) async def initialize_rag_system(self): """Initializes the RAG system by creating database tables.""" @@ -55,6 +56,7 @@ class RAGManager: session.commit() session.refresh(new_kb) self.ap.logger.info(f"Knowledge Base '{kb_name}' created.") + print(embedding_model_uuid) return new_kb.id else: self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.") @@ -158,10 +160,9 @@ class RAGManager: kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() if not kb: self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ') - self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}') - else: - self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.") - + return + # get embedding model + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid) file_name = os.path.basename(file_path) text = await self.parser.parse(file_path) if not text: @@ -172,7 +173,7 @@ class RAGManager: chunks_texts = await self.chunker.chunk(text) self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") - await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts) + await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model) self.ap.logger.info(f'Data storage process completed for file: {file_path}') except Exception as e: diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 063ae79e..4da7e82a 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -1,4 +1,4 @@ -# services/embedder.py +from __future__ import annotations import asyncio import logging import numpy as np @@ -6,30 +6,23 @@ from typing import List from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from sqlalchemy.orm import declarative_base, sessionmaker +from ....core import app +from ....entity.persistence import model as persistence_model +import sqlalchemy +from ....provider.modelmgr.requester import RuntimeEmbeddingModel + +base = declarative_base() logger = logging.getLogger(__name__) class Embedder(BaseService): - def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None): + def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None: super().__init__() self.logger = logging.getLogger(self.__class__.__name__) - self.model_type = model_type - self.model_name_key = model_name_key - self.chroma_manager = chroma_manager # Dependency Injection - - self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() - - def _load_embedding_model(self) -> BaseEmbeddingModel: - self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...") - try: - model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) - self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") - return model - except Exception as e: - self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}") - raise + self.chroma_manager = chroma_manager + self.ap = ap def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): """ @@ -47,12 +40,10 @@ class Embedder(BaseService): self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.") return chunk_objects - async def embed_and_store(self, file_id: int, chunks: List[str]): - if not self.embedding_model: + async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]: + if not embedding_model: raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.") - self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...") - session = SessionLocal() # Start a session that will live for the whole operation chunk_objects = [] try: @@ -65,17 +56,23 @@ class Embedder(BaseService): if not chunk_objects: self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.") return [] - - # 2. Generate embeddings - embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks) + + # get the embeddings for the chunks + embeddings = [] + i = 0 + while i BaseEmbeddingModel: - self.logger.info( - f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...' - ) - try: - model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) - self.logger.info( - f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}" - ) - return model - except Exception as e: - self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") - raise + self.ap = ap async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: if not self.embedding_model: