Files
LangBot/src/langbot/pkg/vector/vdbs/pgvector_db.py
2026-01-23 13:30:44 +08:00

257 lines
9.3 KiB
Python

from __future__ import annotations
from typing import Any, Dict
from sqlalchemy import create_engine, text, Column, String, Text
from sqlalchemy.orm import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from pgvector.sqlalchemy import Vector
from langbot.pkg.vector.vdb import VectorDatabase
from langbot.pkg.core import app
# Declarative base for the ORM models in this module; the table definitions
# registered on ``Base.metadata`` are what ``create_all`` materialises.
Base = declarative_base()

# Fixed dimension the ``embedding`` column is declared with. Nothing in this
# module resizes the column at runtime, so the embedding model in use has to
# produce vectors of exactly this length.
# NOTE(review): pgvector is expected to reject vectors of any other length for
# a ``vector(n)`` column - confirm before switching embedding models.
DEFAULT_EMBEDDING_DIM = 1536


class PgVectorEntry(Base):
    """SQLAlchemy model for pgvector entries.

    One row is one embedded text chunk. Every collection shares this single
    table; rows are told apart by the ``collection`` column.
    """

    __tablename__ = 'langbot_vectors'

    # Caller-supplied unique id of the vector (primary key).
    id = Column(String, primary_key=True)
    # Name of the logical collection the row belongs to; indexed because
    # queries and deletes filter on it.
    collection = Column(String, index=True, nullable=False)
    # The embedding itself, stored as a fixed-dimension pgvector column.
    embedding = Column(Vector(DEFAULT_EMBEDDING_DIM))
    # Text taken from the ``text`` key of the entry's metadata.
    text = Column(Text)
    # Value of the ``file_id`` metadata key; indexed for per-file deletion.
    file_id = Column(String, index=True)
    # Value of the ``uuid`` metadata key.
    chunk_uuid = Column(String)
class PgVectorDatabase(VectorDatabase):
    """PostgreSQL with pgvector extension database implementation.

    All vectors are stored in the single table mapped by ``PgVectorEntry``.
    A "collection" is purely logical: it is the value of the row's
    ``collection`` column, and the names seen so far are additionally tracked
    in an in-memory set for logging purposes.

    Two engines are kept: a synchronous one, used only at start-up to enable
    the extension and create the table, and an asynchronous one used by every
    public coroutine.
    """

    def __init__(
        self,
        ap: app.Application,
        connection_string: str | None = None,
        host: str = 'localhost',
        port: int = 5432,
        database: str = 'langbot',
        user: str = 'postgres',
        password: str = 'postgres',
    ):
        """Initialize pgvector database.

        Connects immediately: the constructor enables the ``vector``
        extension and creates the table before returning.

        Args:
            ap: Application instance (its ``logger`` is used for all logging).
            connection_string: Full PostgreSQL connection string (overrides
                the other connection parameters).
            host: PostgreSQL host.
            port: PostgreSQL port.
            database: Database name.
            user: Database user.
            password: Database password.

        Raises:
            Exception: Whatever SQLAlchemy or the driver raises when the URL
                is invalid or the database cannot be reached/initialised.
        """
        # Local import keeps this block self-contained; sqlalchemy is already
        # a dependency of this module.
        from sqlalchemy.engine import URL, make_url

        self.ap = ap

        if connection_string:
            url = make_url(connection_string)
            self.connection_string = connection_string
        else:
            # URL.create escapes special characters ('@', '/', ':' ...) in the
            # credentials, which plain string formatting would not do.
            url = URL.create(
                'postgresql+psycopg',
                username=user,
                password=password,
                host=host,
                port=port,
                database=database,
            )
            self.connection_string = url.render_as_string(hide_password=False)

        # Force the asyncpg driver whatever driver the caller named, so that
        # e.g. 'postgresql+psycopg2://' does not reach create_async_engine.
        self.async_connection_string = url.set(drivername='postgresql+asyncpg').render_as_string(
            hide_password=False
        )

        self.engine = None
        self.async_engine = None
        self.SessionLocal = None
        self.AsyncSessionLocal = None
        self._collections: set[str] = set()
        self._initialize_db()

    def _initialize_db(self):
        """Initialize database connection and create tables.

        Creates the async engine and its session factory, then uses a
        synchronous engine to run ``CREATE EXTENSION IF NOT EXISTS vector``
        and ``Base.metadata.create_all``.

        Raises:
            Exception: Any failure is logged and re-raised; the synchronous
                engine is disposed first so no pooled connection is leaked.
        """
        from sqlalchemy.engine import make_url

        try:
            # Create async engine for async operations
            self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True)
            self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)

            # Create sync engine for table creation. An asyncpg URL cannot be
            # used synchronously, so it is switched to the psycopg driver.
            sync_url = make_url(self.connection_string)
            if sync_url.drivername == 'postgresql+asyncpg':
                sync_url = sync_url.set(drivername='postgresql+psycopg')
            self.engine = create_engine(sync_url, echo=False)

            # Enable pgvector extension
            with self.engine.connect() as conn:
                conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
                conn.commit()

            # Create tables
            Base.metadata.create_all(self.engine)
            self.ap.logger.info('Connected to PostgreSQL with pgvector')
        except Exception as e:
            self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}')
            # __init__ is about to fail, so close() will never be called:
            # release whatever the sync pool may have opened.
            if self.engine is not None:
                self.engine.dispose()
                self.engine = None
            raise

    async def get_or_create_collection(self, collection: str):
        """Get or create a collection (logical grouping in pgvector).

        No database work is done here; the name is only recorded in the
        in-memory set the first time it is seen.

        Args:
            collection: Collection name (knowledge base UUID).

        Returns:
            The collection name that was passed in.
        """
        # In pgvector, collections are logical - we just track them
        if collection not in self._collections:
            self._collections.add(collection)
            self.ap.logger.info(f"Registered pgvector collection '{collection}'")
        return collection

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Add vector embeddings to pgvector.

        All rows are inserted in one transaction; on any error the
        transaction is rolled back and the error re-raised.

        Args:
            collection: Collection name.
            ids: List of unique IDs for each vector.
            embeddings_list: List of embedding vectors, parallel to ``ids``.
            metadatas: List of metadata dictionaries, parallel to ``ids``.
                The keys ``text``, ``file_id`` and ``uuid`` are stored;
                missing entries or keys default to an empty string.

        Raises:
            ValueError: If ``ids`` and ``embeddings_list`` differ in length.
            Exception: Any database error, after rollback and logging.
        """
        if len(ids) != len(embeddings_list):
            raise ValueError(
                f'ids and embeddings_list must have the same length '
                f'(got {len(ids)} and {len(embeddings_list)})'
            )

        await self.get_or_create_collection(collection)

        # Nothing to insert: avoid opening a session for an empty commit.
        if not ids:
            return

        metadatas = metadatas or []
        entries = []
        for i, (vector_id, embedding) in enumerate(zip(ids, embeddings_list)):
            # Metadata is optional per entry; a short list or a None item is
            # treated as "no metadata".
            metadata = (metadatas[i] if i < len(metadatas) else None) or {}
            entries.append(
                PgVectorEntry(
                    id=vector_id,
                    collection=collection,
                    embedding=embedding,
                    text=metadata.get('text', ''),
                    file_id=metadata.get('file_id', ''),
                    chunk_uuid=metadata.get('uuid', ''),
                )
            )

        async with self.AsyncSessionLocal() as session:
            try:
                session.add_all(entries)
                await session.commit()
                self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'")
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
                raise

    async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
        """Search for similar vectors using cosine distance.

        Args:
            collection: Collection name.
            query_embedding: Query vector.
            k: Number of top results to return.

        Returns:
            Dictionary with search results in Chroma-compatible format:
            ``{'ids': [[...]], 'distances': [[...]], 'metadatas': [[...]]}``,
            ordered by ascending cosine distance. Each metadata dict has the
            keys ``text``, ``file_id`` and ``uuid``.

        Raises:
            Exception: Any database error, after logging.
        """
        from sqlalchemy import select

        await self.get_or_create_collection(collection)

        # Build the distance expression once; it is both selected and used
        # as the sort key.
        distance = PgVectorEntry.embedding.cosine_distance(query_embedding)
        stmt = (
            select(
                PgVectorEntry.id,
                PgVectorEntry.text,
                PgVectorEntry.file_id,
                PgVectorEntry.chunk_uuid,
                distance.label('distance'),
            )
            .filter(PgVectorEntry.collection == collection)
            .order_by(distance)
            .limit(k)
        )

        async with self.AsyncSessionLocal() as session:
            try:
                result = await session.execute(stmt)
                rows = result.fetchall()

                # Convert to Chroma-compatible format (one inner list per
                # query; there is always exactly one query here).
                ids = [row.id for row in rows]
                distances = [float(row.distance) for row in rows]
                metadatas = [
                    {'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''}
                    for row in rows
                ]
                result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}

                self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results")
                return result_dict
            except Exception as e:
                self.ap.logger.error(f'Error searching pgvector: {e}')
                raise

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """Delete vectors by file_id.

        Args:
            collection: Collection name.
            file_id: File ID to filter deletion.

        Raises:
            Exception: Any database error, after rollback and logging.
        """
        from sqlalchemy import delete

        await self.get_or_create_collection(collection)

        stmt = delete(PgVectorEntry).where(
            PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id
        )

        async with self.AsyncSessionLocal() as session:
            try:
                await session.execute(stmt)
                await session.commit()
                self.ap.logger.info(
                    f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
                )
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error deleting from pgvector: {e}')
                raise

    async def delete_collection(self, collection: str):
        """Delete all vectors in a collection.

        Args:
            collection: Collection name to delete.

        Raises:
            Exception: Any database error, after rollback and logging. The
                collection stays registered in memory in that case.
        """
        from sqlalchemy import delete

        stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)

        async with self.AsyncSessionLocal() as session:
            try:
                await session.execute(stmt)
                await session.commit()
                # Forget the collection only once its rows are really gone,
                # so a failed delete does not desynchronise the registry.
                self._collections.discard(collection)
                self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
            except Exception as e:
                await session.rollback()
                self.ap.logger.error(f'Error deleting pgvector collection: {e}')
                raise

    async def close(self):
        """Close database connections.

        Disposes both engines. Safe to call more than once: each engine
        reference is cleared after it has been disposed.
        """
        if self.async_engine:
            await self.async_engine.dispose()
            self.async_engine = None
        if self.engine:
            self.engine.dispose()
            self.engine = None