mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 16:04:21 +00:00
perf: ruff check --fix
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# services/retriever.py
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
@@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Retriever(BaseService):
|
||||
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
|
||||
super().__init__()
|
||||
@@ -22,10 +22,14 @@ class Retriever(BaseService):
|
||||
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
|
||||
|
||||
def _load_embedding_model(self) -> BaseEmbeddingModel:
|
||||
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
|
||||
self.logger.info(
|
||||
f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
|
||||
)
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
|
||||
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
|
||||
self.logger.info(
|
||||
f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
|
||||
)
|
||||
return model
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
|
||||
@@ -33,43 +37,42 @@ class Retriever(BaseService):
|
||||
|
||||
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
||||
if not self.embedding_model:
|
||||
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
|
||||
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
|
||||
|
||||
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
|
||||
|
||||
query_embedding: List[float] = await self.embedding_model.embed_query(query)
|
||||
query_embedding_np = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
chroma_results = await self._run_sync(
|
||||
self.chroma_manager.search_sync,
|
||||
query_embedding_np, k
|
||||
)
|
||||
chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k)
|
||||
|
||||
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
|
||||
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
|
||||
distances = chroma_results.get("distances", [[]])[0]
|
||||
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
|
||||
chroma_documents = chroma_results.get("documents", [[]])[0]
|
||||
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
|
||||
distances = chroma_results.get('distances', [[]])[0]
|
||||
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
|
||||
chroma_documents = chroma_results.get('documents', [[]])[0]
|
||||
|
||||
if not matched_chroma_ids:
|
||||
self.logger.info("No relevant chunks found in Chroma.")
|
||||
self.logger.info('No relevant chunks found in Chroma.')
|
||||
return []
|
||||
|
||||
db_chunk_ids = []
|
||||
for metadata in chroma_metadatas:
|
||||
if "chunk_id" in metadata:
|
||||
db_chunk_ids.append(metadata["chunk_id"])
|
||||
if 'chunk_id' in metadata:
|
||||
db_chunk_ids.append(metadata['chunk_id'])
|
||||
else:
|
||||
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
|
||||
|
||||
if not db_chunk_ids:
|
||||
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
|
||||
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
|
||||
return []
|
||||
|
||||
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
|
||||
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
|
||||
chunks_from_db = await self._run_sync(
|
||||
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids
|
||||
lambda cids: self._db_get_chunks_sync(
|
||||
SessionLocal(), cids
|
||||
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids,
|
||||
)
|
||||
|
||||
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
|
||||
@@ -80,27 +83,29 @@ class Retriever(BaseService):
|
||||
# Ensure original_chunk_id is int for DB lookup
|
||||
original_chunk_id = int(chroma_id.split('_')[-1])
|
||||
except (ValueError, IndexError):
|
||||
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
|
||||
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
|
||||
continue
|
||||
|
||||
chunk_text_from_chroma = chroma_documents[i]
|
||||
distance = float(distances[i])
|
||||
file_id_from_chroma = chroma_metadatas[i].get("file_id")
|
||||
file_id_from_chroma = chroma_metadatas[i].get('file_id')
|
||||
|
||||
chunk_from_db = chunk_map.get(original_chunk_id)
|
||||
|
||||
results_list.append({
|
||||
"chunk_id": original_chunk_id,
|
||||
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
"distance": distance,
|
||||
"file_id": file_id_from_chroma
|
||||
})
|
||||
results_list.append(
|
||||
{
|
||||
'chunk_id': original_chunk_id,
|
||||
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
'distance': distance,
|
||||
'file_id': file_id_from_chroma,
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(f"Retrieved {len(results_list)} chunks.")
|
||||
self.logger.info(f'Retrieved {len(results_list)} chunks.')
|
||||
return results_list
|
||||
|
||||
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
|
||||
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
|
||||
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
|
||||
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
|
||||
session.close()
|
||||
return chunks
|
||||
return chunks
|
||||
|
||||
Reference in New Issue
Block a user