mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 15:56:03 +00:00
feat: add milvus and pgvector as vector db (#1840)
* feat: add milvus and pgvector as vector db * chore: update config.yaml template delete comments
This commit is contained in:
committed by
GitHub
parent
6bf08466de
commit
86e951916e
286
src/langbot/pkg/vector/vdbs/pgvector_db.py
Normal file
286
src/langbot/pkg/vector/vdbs/pgvector_db.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from sqlalchemy import create_engine, text, Column, String, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PgVectorEntry(Base):
|
||||
"""SQLAlchemy model for pgvector entries"""
|
||||
__tablename__ = 'langbot_vectors'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
collection = Column(String, index=True, nullable=False)
|
||||
embedding = Column(Vector(1536)) # Default dimension, will be created dynamically
|
||||
text = Column(Text)
|
||||
file_id = Column(String, index=True)
|
||||
chunk_uuid = Column(String)
|
||||
|
||||
|
||||
class PgVectorDatabase(VectorDatabase):
|
||||
"""PostgreSQL with pgvector extension database implementation"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
connection_string: str = None,
|
||||
host: str = "localhost",
|
||||
port: int = 5432,
|
||||
database: str = "langbot",
|
||||
user: str = "postgres",
|
||||
password: str = "postgres"
|
||||
):
|
||||
"""Initialize pgvector database
|
||||
|
||||
Args:
|
||||
ap: Application instance
|
||||
connection_string: Full PostgreSQL connection string (overrides other params)
|
||||
host: PostgreSQL host
|
||||
port: PostgreSQL port
|
||||
database: Database name
|
||||
user: Database user
|
||||
password: Database password
|
||||
"""
|
||||
self.ap = ap
|
||||
|
||||
# Build connection string if not provided
|
||||
if connection_string:
|
||||
self.connection_string = connection_string
|
||||
else:
|
||||
self.connection_string = (
|
||||
f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"
|
||||
)
|
||||
|
||||
self.async_connection_string = self.connection_string.replace(
|
||||
"postgresql://", "postgresql+asyncpg://"
|
||||
).replace(
|
||||
"postgresql+psycopg://", "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
self.engine = None
|
||||
self.async_engine = None
|
||||
self.SessionLocal = None
|
||||
self.AsyncSessionLocal = None
|
||||
self._collections = set()
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""Initialize database connection and create tables"""
|
||||
try:
|
||||
# Create async engine for async operations
|
||||
self.async_engine = create_async_engine(
|
||||
self.async_connection_string,
|
||||
echo=False,
|
||||
pool_pre_ping=True
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# Create sync engine for table creation
|
||||
sync_connection_string = self.connection_string.replace(
|
||||
"postgresql+asyncpg://", "postgresql+psycopg://"
|
||||
)
|
||||
self.engine = create_engine(sync_connection_string, echo=False)
|
||||
|
||||
# Create pgvector extension and tables
|
||||
with self.engine.connect() as conn:
|
||||
# Enable pgvector extension
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.commit()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
self.ap.logger.info(f"Connected to PostgreSQL with pgvector")
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a collection (logical grouping in pgvector)
|
||||
|
||||
Args:
|
||||
collection: Collection name (knowledge base UUID)
|
||||
"""
|
||||
# In pgvector, collections are logical - we just track them
|
||||
if collection not in self._collections:
|
||||
self._collections.add(collection)
|
||||
self.ap.logger.info(f"Registered pgvector collection '{collection}'")
|
||||
return collection
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Add vector embeddings to pgvector
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
ids: List of unique IDs for each vector
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
for i, vector_id in enumerate(ids):
|
||||
metadata = metadatas[i] if i < len(metadatas) else {}
|
||||
|
||||
entry = PgVectorEntry(
|
||||
id=vector_id,
|
||||
collection=collection,
|
||||
embedding=embeddings_list[i],
|
||||
text=metadata.get("text", ""),
|
||||
file_id=metadata.get("file_id", ""),
|
||||
chunk_uuid=metadata.get("uuid", "")
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
await session.commit()
|
||||
self.ap.logger.info(
|
||||
f"Added {len(ids)} embeddings to pgvector collection '{collection}'"
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error adding embeddings to pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for similar vectors using cosine distance
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
query_embedding: Query vector
|
||||
k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Dictionary with search results in Chroma-compatible format
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
# Use cosine distance for similarity search
|
||||
from sqlalchemy import select, func
|
||||
|
||||
# Query for similar vectors
|
||||
stmt = (
|
||||
select(
|
||||
PgVectorEntry.id,
|
||||
PgVectorEntry.text,
|
||||
PgVectorEntry.file_id,
|
||||
PgVectorEntry.chunk_uuid,
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance')
|
||||
)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
.order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
|
||||
.limit(k)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert to Chroma-compatible format
|
||||
ids = []
|
||||
distances = []
|
||||
metadatas = []
|
||||
|
||||
for row in rows:
|
||||
ids.append(row.id)
|
||||
distances.append(float(row.distance))
|
||||
metadatas.append({
|
||||
"text": row.text or "",
|
||||
"file_id": row.file_id or "",
|
||||
"uuid": row.chunk_uuid or ""
|
||||
})
|
||||
|
||||
result_dict = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"pgvector search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error searching pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
"""Delete vectors by file_id
|
||||
|
||||
Args:
|
||||
collection: Collection name
|
||||
file_id: File ID to filter deletion
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection,
|
||||
PgVectorEntry.file_id == file_id
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from pgvector collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting from pgvector: {e}")
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete all vectors in a collection
|
||||
|
||||
Args:
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
if collection in self._collections:
|
||||
self._collections.remove(collection)
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting pgvector collection: {e}")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Close database connections"""
|
||||
if self.async_engine:
|
||||
await self.async_engine.dispose()
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
Reference in New Issue
Block a user