feat: add embedder

This commit is contained in:
WangCham
2025-07-13 23:04:03 +08:00
parent 234b61e2f8
commit b7c57104c4
6 changed files with 43 additions and 63 deletions

View File

@@ -1,4 +1,4 @@
# services/embedder.py
from __future__ import annotations
import asyncio
import logging
import numpy as np
@@ -6,30 +6,23 @@ from typing import List
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from sqlalchemy.orm import declarative_base, sessionmaker
from ....core import app
from ....entity.persistence import model as persistence_model
import sqlalchemy
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
base = declarative_base()
logger = logging.getLogger(__name__)
class Embedder(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None):
def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None:
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.model_type = model_type
self.model_name_key = model_name_key
self.chroma_manager = chroma_manager # Dependency Injection
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...")
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
return model
except Exception as e:
self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}")
raise
self.chroma_manager = chroma_manager
self.ap = ap
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
"""
@@ -47,12 +40,10 @@ class Embedder(BaseService):
self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
return chunk_objects
async def embed_and_store(self, file_id: int, chunks: List[str]):
if not self.embedding_model:
async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]:
if not embedding_model:
raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")
self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...")
session = SessionLocal() # Start a session that will live for the whole operation
chunk_objects = []
try:
@@ -65,17 +56,23 @@ class Embedder(BaseService):
if not chunk_objects:
self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
return []
# 2. Generate embeddings
embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks)
# get the embeddings for the chunks
embeddings = []
i = 0
while i <len(chunks):
chunk = chunks[i]
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
i += 1
embeddings_np = np.array(embeddings, dtype=np.float32)
if embeddings_np.shape[1] != self.embedding_model.embedding_dimension:
self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.")
raise ValueError("Embedding dimension mismatch during embedding process.")
self.logger.info("Saving embeddings to Chroma...")
chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed
chunk_ids = [c.id for c in chunk_objects]
file_ids_for_chroma = [file_id] * len(chunk_ids)
await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call

View File

@@ -1,39 +1,22 @@
# services/retriever.py
from __future__ import annotations
import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from ....core import app
logger = logging.getLogger(__name__)
class Retriever(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
def __init__(self, ap:app.Application, chroma_manager: ChromaIndexManager):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.model_type = model_type
self.model_name_key = model_name_key
self.chroma_manager = chroma_manager
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(
f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
)
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(
f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
)
return model
except Exception as e:
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
raise
self.ap = ap
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model: