mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 20:14:36 +00:00
257 lines
9.3 KiB
Python
257 lines
9.3 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Dict
|
|
from sqlalchemy import create_engine, text, Column, String, Text
|
|
from sqlalchemy.orm import declarative_base
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from pgvector.sqlalchemy import Vector
|
|
from langbot.pkg.vector.vdb import VectorDatabase
|
|
from langbot.pkg.core import app
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
class PgVectorEntry(Base):
|
|
"""SQLAlchemy model for pgvector entries"""
|
|
|
|
__tablename__ = 'langbot_vectors'
|
|
|
|
id = Column(String, primary_key=True)
|
|
collection = Column(String, index=True, nullable=False)
|
|
embedding = Column(Vector(1536)) # Default dimension, will be created dynamically
|
|
text = Column(Text)
|
|
file_id = Column(String, index=True)
|
|
chunk_uuid = Column(String)
|
|
|
|
|
|
class PgVectorDatabase(VectorDatabase):
|
|
"""PostgreSQL with pgvector extension database implementation"""
|
|
|
|
def __init__(
|
|
self,
|
|
ap: app.Application,
|
|
connection_string: str = None,
|
|
host: str = 'localhost',
|
|
port: int = 5432,
|
|
database: str = 'langbot',
|
|
user: str = 'postgres',
|
|
password: str = 'postgres',
|
|
):
|
|
"""Initialize pgvector database
|
|
|
|
Args:
|
|
ap: Application instance
|
|
connection_string: Full PostgreSQL connection string (overrides other params)
|
|
host: PostgreSQL host
|
|
port: PostgreSQL port
|
|
database: Database name
|
|
user: Database user
|
|
password: Database password
|
|
"""
|
|
self.ap = ap
|
|
|
|
# Build connection string if not provided
|
|
if connection_string:
|
|
self.connection_string = connection_string
|
|
else:
|
|
self.connection_string = f'postgresql+psycopg://{user}:{password}@{host}:{port}/{database}'
|
|
|
|
self.async_connection_string = self.connection_string.replace('postgresql://', 'postgresql+asyncpg://').replace(
|
|
'postgresql+psycopg://', 'postgresql+asyncpg://'
|
|
)
|
|
|
|
self.engine = None
|
|
self.async_engine = None
|
|
self.SessionLocal = None
|
|
self.AsyncSessionLocal = None
|
|
self._collections = set()
|
|
self._initialize_db()
|
|
|
|
def _initialize_db(self):
|
|
"""Initialize database connection and create tables"""
|
|
try:
|
|
# Create async engine for async operations
|
|
self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True)
|
|
self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
|
# Create sync engine for table creation
|
|
sync_connection_string = self.connection_string.replace('postgresql+asyncpg://', 'postgresql+psycopg://')
|
|
self.engine = create_engine(sync_connection_string, echo=False)
|
|
|
|
# Create pgvector extension and tables
|
|
with self.engine.connect() as conn:
|
|
# Enable pgvector extension
|
|
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
|
conn.commit()
|
|
|
|
# Create tables
|
|
Base.metadata.create_all(self.engine)
|
|
|
|
self.ap.logger.info('Connected to PostgreSQL with pgvector')
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}')
|
|
raise
|
|
|
|
async def get_or_create_collection(self, collection: str):
|
|
"""Get or create a collection (logical grouping in pgvector)
|
|
|
|
Args:
|
|
collection: Collection name (knowledge base UUID)
|
|
"""
|
|
# In pgvector, collections are logical - we just track them
|
|
if collection not in self._collections:
|
|
self._collections.add(collection)
|
|
self.ap.logger.info(f"Registered pgvector collection '{collection}'")
|
|
return collection
|
|
|
|
async def add_embeddings(
|
|
self,
|
|
collection: str,
|
|
ids: list[str],
|
|
embeddings_list: list[list[float]],
|
|
metadatas: list[dict[str, Any]],
|
|
) -> None:
|
|
"""Add vector embeddings to pgvector
|
|
|
|
Args:
|
|
collection: Collection name
|
|
ids: List of unique IDs for each vector
|
|
embeddings_list: List of embedding vectors
|
|
metadatas: List of metadata dictionaries
|
|
"""
|
|
await self.get_or_create_collection(collection)
|
|
|
|
async with self.AsyncSessionLocal() as session:
|
|
try:
|
|
for i, vector_id in enumerate(ids):
|
|
metadata = metadatas[i] if i < len(metadatas) else {}
|
|
|
|
entry = PgVectorEntry(
|
|
id=vector_id,
|
|
collection=collection,
|
|
embedding=embeddings_list[i],
|
|
text=metadata.get('text', ''),
|
|
file_id=metadata.get('file_id', ''),
|
|
chunk_uuid=metadata.get('uuid', ''),
|
|
)
|
|
session.add(entry)
|
|
|
|
await session.commit()
|
|
self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'")
|
|
except Exception as e:
|
|
await session.rollback()
|
|
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
|
raise
|
|
|
|
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
|
"""Search for similar vectors using cosine distance
|
|
|
|
Args:
|
|
collection: Collection name
|
|
query_embedding: Query vector
|
|
k: Number of top results to return
|
|
|
|
Returns:
|
|
Dictionary with search results in Chroma-compatible format
|
|
"""
|
|
await self.get_or_create_collection(collection)
|
|
|
|
async with self.AsyncSessionLocal() as session:
|
|
try:
|
|
# Use cosine distance for similarity search
|
|
from sqlalchemy import select
|
|
|
|
# Query for similar vectors
|
|
stmt = (
|
|
select(
|
|
PgVectorEntry.id,
|
|
PgVectorEntry.text,
|
|
PgVectorEntry.file_id,
|
|
PgVectorEntry.chunk_uuid,
|
|
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance'),
|
|
)
|
|
.filter(PgVectorEntry.collection == collection)
|
|
.order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
|
|
.limit(k)
|
|
)
|
|
|
|
result = await session.execute(stmt)
|
|
rows = result.fetchall()
|
|
|
|
# Convert to Chroma-compatible format
|
|
ids = []
|
|
distances = []
|
|
metadatas = []
|
|
|
|
for row in rows:
|
|
ids.append(row.id)
|
|
distances.append(float(row.distance))
|
|
metadatas.append(
|
|
{'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''}
|
|
)
|
|
|
|
result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
|
|
|
self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results")
|
|
return result_dict
|
|
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Error searching pgvector: {e}')
|
|
raise
|
|
|
|
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
|
"""Delete vectors by file_id
|
|
|
|
Args:
|
|
collection: Collection name
|
|
file_id: File ID to filter deletion
|
|
"""
|
|
await self.get_or_create_collection(collection)
|
|
|
|
async with self.AsyncSessionLocal() as session:
|
|
try:
|
|
from sqlalchemy import delete
|
|
|
|
stmt = delete(PgVectorEntry).where(
|
|
PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id
|
|
)
|
|
await session.execute(stmt)
|
|
await session.commit()
|
|
|
|
self.ap.logger.info(
|
|
f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
|
|
)
|
|
except Exception as e:
|
|
await session.rollback()
|
|
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
|
raise
|
|
|
|
async def delete_collection(self, collection: str):
|
|
"""Delete all vectors in a collection
|
|
|
|
Args:
|
|
collection: Collection name to delete
|
|
"""
|
|
if collection in self._collections:
|
|
self._collections.remove(collection)
|
|
|
|
async with self.AsyncSessionLocal() as session:
|
|
try:
|
|
from sqlalchemy import delete
|
|
|
|
stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)
|
|
await session.execute(stmt)
|
|
await session.commit()
|
|
|
|
self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
|
|
except Exception as e:
|
|
await session.rollback()
|
|
self.ap.logger.error(f'Error deleting pgvector collection: {e}')
|
|
raise
|
|
|
|
async def close(self):
|
|
"""Close database connections"""
|
|
if self.async_engine:
|
|
await self.async_engine.dispose()
|
|
if self.engine:
|
|
self.engine.dispose()
|