feat(rag): make embedding and retrieving available

This commit is contained in:
Junyan Qin
2025-07-16 21:17:18 +08:00
parent f731115805
commit 2f2db4d445
20 changed files with 180 additions and 368 deletions

View File

@@ -1,26 +1,15 @@
# 封装异步操作
import asyncio
import logging
from pkg.rag.knowledge.services.database import SessionLocal
class BaseService:
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.db_session_factory = SessionLocal
pass
async def _run_sync(self, func, *args, **kwargs):
"""
在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。
"""
if getattr(func, '__name__', '').startswith('_db_'):
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)
return result
finally:
session.close()
else:
# 否则,直接运行同步函数
return await asyncio.to_thread(func, *args, **kwargs)
return await asyncio.to_thread(func, *args, **kwargs)

View File

@@ -1,24 +1,21 @@
# services/chunker.py
import logging
from __future__ import annotations
from typing import List
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services import base_service
from pkg.core import app
logger = logging.getLogger(__name__)
class Chunker(BaseService):
class Chunker(base_service.BaseService):
"""
A class for splitting long texts into smaller, overlapping chunks.
"""
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__(ap) # Initialize BaseService
self.ap = ap
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size:
self.logger.warning(
self.ap.logger.warning(
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
)

View File

@@ -1,23 +0,0 @@
# 全部迁移过去
from pkg.entity.persistence.rag import (
create_db_and_tables,
SessionLocal,
Base,
engine,
KnowledgeBase,
File,
Chunk,
Vector,
)
__all__ = [
"create_db_and_tables",
"SessionLocal",
"Base",
"engine",
"KnowledgeBase",
"File",
"Chunk",
"Vector",
]

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import asyncio
import numpy as np
import uuid
from typing import List
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from ....entity.persistence import rag as persistence_rag
from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
import sqlalchemy
class Embedder(BaseService):
@@ -14,74 +13,41 @@ class Embedder(BaseService):
super().__init__()
self.ap = ap
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
"""
Saves chunks to the relational database and returns the created Chunk objects.
This function assumes it's called within a context where the session
will be committed/rolled back and closed by the caller.
"""
self.ap.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')
chunk_objects = []
for text in chunks_texts:
chunk = Chunk(file_id=file_id, text=text)
session.add(chunk)
chunk_objects.append(chunk)
session.flush() # This populates the .id attribute for each new chunk object
self.ap.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.')
return chunk_objects
async def embed_and_store(
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> List[Chunk]:
session = SessionLocal() # Start a session that will live for the whole operation
chunk_objects = []
try:
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> list[persistence_rag.Chunk]:
# save chunk to db
chunk_entities: list[persistence_rag.Chunk] = []
chunk_ids: list[str] = []
if not chunk_objects:
self.ap.logger.warning(
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.'
)
return []
for chunk_text in chunks:
chunk_uuid = str(uuid.uuid4())
chunk_ids.append(chunk_uuid)
chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
chunk_entities.append(chunk_entity)
# get the embeddings for the chunks
embeddings: list[list[float]] = []
chunk_dicts = [
self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
]
for chunk in chunks:
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
embeddings_np = np.array(embeddings, dtype=np.float32)
# get embeddings
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunks,
extra_args={}, # TODO: add extra args
)
chunk_ids = [c.id for c in chunk_objects]
# collection名用kb_idfile对象有kb_id字段
kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None
if not kb_id:
self.ap.logger.warning('无法获取kb_id向量存储失败')
return chunk_objects
chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids]
metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids]
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chroma_ids,
embeddings_np,
metadatas,
chunks,
)
self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.')
return chunk_objects
# save embeddings to vdb
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chunk_ids,
embeddings_list,
chunk_dicts,
)
except Exception as e:
session.rollback() # Rollback on any error
self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True)
raise # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
return chunk_entities

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import PyPDF2
import io
from docx import Document

View File

@@ -1,99 +1,46 @@
from __future__ import annotations
import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.vector.vdb import VectorDatabase
from . import base_service
from ....core import app
logger = logging.getLogger(__name__)
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
from ....entity.rag import retriever as retriever_entities
class Retriever(BaseService):
class Retriever(base_service.BaseService):
def __init__(self, ap: app.Application):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.ap = ap
self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model:
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
async def retrieve(
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
) -> list[retriever_entities.RetrieveResultEntry]:
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=[query],
extra_args={}, # TODO: add extra args
)
query_embedding: List[float] = await self.embedding_model.embed_query(query)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
# collection名用kb_id假设retriever有kb_id属性或通过ap传递
kb_id = getattr(self, 'kb_id', None)
if not kb_id:
self.logger.warning('无法获取kb_id向量检索失败')
return []
chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k)
chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
distances = chroma_results.get('distances', [[]])[0]
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
chroma_documents = chroma_results.get('documents', [[]])[0]
if not matched_chroma_ids:
self.logger.info('No relevant chunks found in Chroma.')
self.ap.logger.info('No relevant chunks found in Chroma.')
return []
db_chunk_ids = []
for metadata in chroma_metadatas:
if 'chunk_id' in metadata:
db_chunk_ids.append(metadata['chunk_id'])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
result: list[retriever_entities.RetrieveResultEntry] = []
if not db_chunk_ids:
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
return []
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(
SessionLocal(), cids
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids,
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
results_list: List[Dict[str, Any]] = []
for i, chroma_id in enumerate(matched_chroma_ids):
try:
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get('file_id')
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append(
{
'chunk_id': original_chunk_id,
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
'distance': distance,
'file_id': file_id_from_chroma,
}
for i, id in enumerate(matched_chroma_ids):
entry = retriever_entities.RetrieveResultEntry(
id=id,
metadata=chroma_metadatas[i],
distance=distances[i],
)
result.append(entry)
self.logger.info(f'Retrieved {len(results_list)} chunks.')
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
session.close()
return chunks
return result