"""pgvector-backed implementation of the ``VectorDatabase`` interface."""

from __future__ import annotations

import operator
from typing import Any, Callable, Dict

from sqlalchemy import create_engine, text, Column, String, Text, delete, func, select
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from pgvector.sqlalchemy import Vector

from langbot.pkg.vector.vdb import VectorDatabase
from langbot.pkg.vector.filter_utils import normalize_filter, strip_unsupported_fields
from langbot.pkg.core import app

Base = declarative_base()

# pgvector schema only stores these metadata fields.
_PG_SUPPORTED_FIELDS = {'text', 'file_id', 'chunk_uuid'}

# Callers use canonical metadata key 'uuid' but pgvector stores it as 'chunk_uuid'.
_PG_FIELD_ALIASES = {'uuid': 'chunk_uuid'}

# Map schema field names to PgVectorEntry attribute names (resolved lazily via getattr).
_PG_COLUMN_MAP = {
    'text': 'text',
    'file_id': 'file_id',
    'chunk_uuid': 'chunk_uuid',
}

# Filter operator -> callable building a SQLAlchemy condition from (column, value).
_PG_OPERATORS: dict[str, Callable[[Any, Any], Any]] = {
    '$eq': operator.eq,
    '$ne': operator.ne,
    '$gt': operator.gt,
    '$gte': operator.ge,
    '$lt': operator.lt,
    '$lte': operator.le,
    '$in': lambda col, values: col.in_(values),
    '$nin': lambda col, values: col.notin_(values),
}


class PgVectorEntry(Base):
    """SQLAlchemy model for pgvector entries"""

    __tablename__ = 'langbot_vectors'

    id = Column(String, primary_key=True)
    collection = Column(String, index=True, nullable=False)
    # Declared with a fixed dimension of 1536; nothing in this module changes it
    # at runtime. NOTE(review): embeddings of another size are presumably
    # rejected by pgvector - confirm before using other embedding models.
    embedding = Column(Vector(1536))
    text = Column(Text)
    file_id = Column(String, index=True)
    chunk_uuid = Column(String)


def _build_pg_conditions(filter_dict: dict[str, Any]) -> list:
    """Translate canonical filter dict into a list of SQLAlchemy conditions.

    The filter is first passed through ``normalize_filter`` and
    ``strip_unsupported_fields`` (given the supported field set and the
    ``uuid`` -> ``chunk_uuid`` alias), which yield ``(field, op, value)``
    triples. Each triple with a known operator becomes one condition on the
    matching ``PgVectorEntry`` column.

    Args:
        filter_dict: Canonical metadata filter dict.

    Returns:
        A list of SQLAlchemy boolean expressions; empty when no triple
        produced a condition.

    Raises:
        KeyError: If a triple names a field missing from ``_PG_COLUMN_MAP``.
    """
    triples = normalize_filter(filter_dict)
    triples = strip_unsupported_fields(triples, _PG_SUPPORTED_FIELDS, _PG_FIELD_ALIASES)

    conditions = []
    for field, op, value in triples:
        col = getattr(PgVectorEntry, _PG_COLUMN_MAP[field])
        build = _PG_OPERATORS.get(op)
        if build is None:
            # Operators outside the table are skipped silently.
            continue
        conditions.append(build(col, value))
    return conditions


class PgVectorDatabase(VectorDatabase):
    """PostgreSQL with pgvector extension database implementation"""

    def __init__(
        self,
        ap: app.Application,
        connection_string: str | None = None,
        host: str = 'localhost',
        port: int = 5432,
        database: str = 'langbot',
        user: str = 'postgres',
        password: str = 'postgres',
    ):
        """Initialize pgvector database

        The constructor connects immediately: it calls ``_initialize_db``,
        which creates the engines, the ``vector`` extension and the tables.

        Args:
            ap: Application instance
            connection_string: Full PostgreSQL connection string (overrides other params)
            host: PostgreSQL host
            port: PostgreSQL port
            database: Database name
            user: Database user
            password: Database password

        Raises:
            Exception: Whatever ``_initialize_db`` raises when the connection
                or schema setup fails.
        """
        self.ap = ap

        # Build connection string if not provided
        if connection_string:
            self.connection_string = connection_string
        else:
            # NOTE(review): components are interpolated verbatim, so a user or
            # password containing URL-reserved characters ('@', '/', ':')
            # presumably has to be percent-encoded by the caller - confirm.
            self.connection_string = f'postgresql+psycopg://{user}:{password}@{host}:{port}/{database}'

        # The async engine always goes through the asyncpg driver.
        self.async_connection_string = self.connection_string.replace('postgresql://', 'postgresql+asyncpg://').replace(
            'postgresql+psycopg://', 'postgresql+asyncpg://'
        )

        self.engine = None
        self.async_engine = None
        self.SessionLocal = None
        self.AsyncSessionLocal = None
        self._collections = set()

        self._initialize_db()

    def _initialize_db(self):
        """Initialize database connection and create tables

        Creates the async engine and session factory, then a sync engine that
        is used to enable the ``vector`` extension and create the tables.

        Raises:
            Exception: Any error raised during setup is logged and re-raised.
        """
        try:
            # Create async engine for async operations
            self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True)
            self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)

            # Create sync engine for table creation
            sync_connection_string = self.connection_string.replace('postgresql+asyncpg://', 'postgresql+psycopg://')
            self.engine = create_engine(sync_connection_string, echo=False)

            # Create pgvector extension and tables
            with self.engine.connect() as conn:
                # Enable pgvector extension
                conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
                conn.commit()

            # Create tables
            Base.metadata.create_all(self.engine)

            self.ap.logger.info('Connected to PostgreSQL with pgvector')
        except Exception as e:
            self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}')
            raise

    async def get_or_create_collection(self, collection: str):
        """Get or create a collection (logical grouping in pgvector)

        Collections are only tracked in an in-memory set; rows are tied to a
        collection through the ``collection`` column.

        Args:
            collection: Collection name (knowledge base UUID)

        Returns:
            The collection name that was passed in.
        """
        # In pgvector, collections are logical - we just track them
        if collection not in self._collections:
            self._collections.add(collection)
            self.ap.logger.info(f"Registered pgvector collection '{collection}'")
        return collection

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
        documents: list[str] | None = None,
    ) -> None:
        """Add vector embeddings to pgvector

        All entries are inserted in one transaction, which is rolled back if
        any step fails.

        Args:
            collection: Collection name
            ids: List of unique IDs for each vector
            embeddings_list: List of embedding vectors
            metadatas: List of metadata dictionaries; the keys ``text``,
                ``file_id`` and ``uuid`` are stored, and an id without a
                matching metadata entry gets empty strings
            documents: Accepted for interface compatibility; not used by this
                backend

        Raises:
            Exception: Any error raised while building or committing the
                entries is logged and re-raised.
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                entries = []
                for i, vector_id in enumerate(ids):
                    metadata = metadatas[i] if i < len(metadatas) else {}
                    entries.append(
                        PgVectorEntry(
                            id=vector_id,
                            collection=collection,
                            embedding=embeddings_list[i],
                            text=metadata.get('text', ''),
                            file_id=metadata.get('file_id', ''),
                            chunk_uuid=metadata.get('uuid', ''),
                        )
                    )
                session.add_all(entries)

                await session.commit()
                self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'")
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
                raise

    async def search(
        self,
        collection: str,
        query_embedding: list[float],
        k: int = 5,
        search_type: str = 'vector',
        query_text: str = '',
        filter: dict[str, Any] | None = None,
        vector_weight: float | None = None,
    ) -> Dict[str, Any]:
        """Search for similar vectors using cosine distance

        Args:
            collection: Collection name
            query_embedding: Query vector
            k: Number of top results to return
            search_type: Ignored by this backend; the search is always a
                vector search
            query_text: Ignored by this backend
            filter: Optional canonical metadata filter applied in addition to
                the collection match
            vector_weight: Ignored by this backend

        Returns:
            Dictionary with search results in Chroma-compatible format:
            ``ids``, ``distances`` and ``metadatas``, each a list holding one
            inner list ordered by ascending cosine distance.

        Raises:
            Exception: Any error raised while building or running the query is
                logged and re-raised.
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                # Build the distance expression once; it is both selected and sorted on.
                distance = PgVectorEntry.embedding.cosine_distance(query_embedding)

                stmt = select(
                    PgVectorEntry.id,
                    PgVectorEntry.text,
                    PgVectorEntry.file_id,
                    PgVectorEntry.chunk_uuid,
                    distance.label('distance'),
                ).where(PgVectorEntry.collection == collection)

                if filter:
                    for cond in _build_pg_conditions(filter):
                        stmt = stmt.where(cond)

                stmt = stmt.order_by(distance).limit(k)

                result = await session.execute(stmt)
                rows = result.fetchall()

                # Convert to Chroma-compatible format
                ids = []
                distances = []
                metadatas = []
                for row in rows:
                    ids.append(row.id)
                    distances.append(float(row.distance))
                    metadatas.append(
                        {'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''}
                    )

                result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}

                self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results")
                return result_dict
            except Exception as e:
                self.ap.logger.error(f'Error searching pgvector: {e}')
                raise

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """Delete vectors by file_id

        Args:
            collection: Collection name
            file_id: File ID to filter deletion

        Raises:
            Exception: Any error raised by the delete is logged and re-raised
                after the transaction is rolled back.
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                stmt = delete(PgVectorEntry).where(
                    PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id
                )
                await session.execute(stmt)
                await session.commit()
                self.ap.logger.info(
                    f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
                )
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error deleting from pgvector: {e}')
                raise

    async def delete_by_filter(self, collection: str, filter: dict[str, Any]) -> int:
        """Delete vectors matching a metadata filter.

        A filter that yields no conditions deletes nothing, so an empty or
        fully unsupported filter can never wipe the whole collection.

        Args:
            collection: Collection name
            filter: Canonical metadata filter dict

        Returns:
            The row count reported by the delete, or 0 when the filter
            produced no conditions.

        Raises:
            Exception: Any error raised by the delete is logged and re-raised
                after the transaction is rolled back.
        """
        conditions = _build_pg_conditions(filter)
        if not conditions:
            self.ap.logger.warning(
                f"pgvector delete_by_filter on '{collection}': filter produced no conditions, skipping"
            )
            return 0

        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection, *conditions)
                result = await session.execute(stmt)
                await session.commit()
                deleted = result.rowcount
                self.ap.logger.info(f"Deleted {deleted} embeddings from pgvector collection '{collection}' by filter")
                return deleted
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error deleting from pgvector by filter: {e}')
                raise

    async def list_by_filter(
        self,
        collection: str,
        filter: dict[str, Any] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[dict[str, Any]], int]:
        """List stored entries of a collection, one page at a time.

        Rows are ordered by ``id`` so that successive pages are stable; without
        an ORDER BY the database may return rows in any order and
        OFFSET/LIMIT paging could repeat or skip entries.

        Args:
            collection: Collection name
            filter: Optional canonical metadata filter dict
            limit: Maximum number of entries to return
            offset: Number of matching entries to skip

        Returns:
            A tuple ``(items, total)``: ``items`` is the page of entries, each
            a dict with ``id``, ``document`` and ``metadata`` keys, and
            ``total`` is the number of entries matching the collection and
            filter.

        Raises:
            Exception: Any error raised while building or running the queries
                is logged and re-raised.
        """
        await self.get_or_create_collection(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                # Shared by the page query and the count query so both agree.
                criteria = [PgVectorEntry.collection == collection]
                if filter:
                    criteria.extend(_build_pg_conditions(filter))

                stmt = (
                    select(
                        PgVectorEntry.id,
                        PgVectorEntry.text,
                        PgVectorEntry.file_id,
                        PgVectorEntry.chunk_uuid,
                    )
                    .where(*criteria)
                    .order_by(PgVectorEntry.id)
                    .offset(offset)
                    .limit(limit)
                )
                count_stmt = select(func.count()).select_from(PgVectorEntry).where(*criteria)

                result = await session.execute(stmt)
                rows = result.fetchall()

                count_result = await session.execute(count_stmt)
                total = count_result.scalar() or 0

                items = [
                    {
                        'id': row.id,
                        'document': row.text or '',
                        'metadata': {
                            'text': row.text or '',
                            'file_id': row.file_id or '',
                            'uuid': row.chunk_uuid or '',
                        },
                    }
                    for row in rows
                ]
                return items, total
            except Exception as e:
                self.ap.logger.error(f'Error listing from pgvector: {e}')
                raise

    async def delete_collection(self, collection: str):
        """Delete all vectors in a collection

        The collection is removed from the in-memory registry first, then its
        rows are deleted.

        Args:
            collection: Collection name to delete

        Raises:
            Exception: Any error raised by the delete is logged and re-raised
                after the transaction is rolled back.
        """
        self._collections.discard(collection)

        async with self.AsyncSessionLocal() as session:
            try:
                stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)
                await session.execute(stmt)
                await session.commit()
                self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error deleting pgvector collection: {e}')
                raise

    async def close(self):
        """Close database connections

        Disposes the async engine and the sync engine if they were created.
        """
        if self.async_engine:
            await self.async_engine.dispose()
        if self.engine:
            self.engine.dispose()