mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-07 06:16:02 +00:00
* feat: add milvus and pgvector as vector db * chore: update config.yaml template delete comments
250 lines
8.5 KiB
Python
250 lines
8.5 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
from typing import Any, Dict
|
|
from pymilvus import MilvusClient, DataType
|
|
from langbot.pkg.vector.vdb import VectorDatabase
|
|
from langbot.pkg.core import app
|
|
|
|
|
|
class MilvusVectorDatabase(VectorDatabase):
    """Milvus vector database implementation.

    Each knowledge base is stored in its own Milvus collection with a fixed
    schema: a string primary key ``id``, a float ``vector`` field and three
    VARCHAR metadata fields (``text``, ``file_id``, ``chunk_uuid``).

    All blocking ``MilvusClient`` calls are dispatched through
    ``asyncio.to_thread`` so they do not block the event loop.
    """

    # Metric used both for the vector index and for searching.
    _METRIC_TYPE = "COSINE"

    # Dimension used when the caller does not supply one (kept for backward
    # compatibility with the previously hard-coded schema).
    _DEFAULT_DIMENSION = 1536

    # (metadata key used by callers, field name in the Milvus schema)
    _METADATA_FIELDS = (
        ("text", "text"),
        ("file_id", "file_id"),
        ("uuid", "chunk_uuid"),
    )

    def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str | None = None):
        """Initialize the Milvus vector database and connect immediately.

        Args:
            ap: Application instance (its ``logger`` is used for reporting).
            uri: Milvus connection URI. For a local file: "milvus.db".
                For a remote server: "http://localhost:19530".
            token: Optional authentication token for remote connections.

        Raises:
            Exception: Whatever ``MilvusClient`` raises when the connection
                cannot be established (logged, then re-raised).
        """
        self.ap = ap
        self.uri = uri
        self.token = token
        self.client = None
        # Names of collections already verified/created by this instance.
        self._collections = {}
        self._initialize_client()

    def _initialize_client(self):
        """Create the ``MilvusClient``; log and re-raise on failure."""
        try:
            if self.token:
                self.client = MilvusClient(uri=self.uri, token=self.token)
            else:
                self.client = MilvusClient(uri=self.uri)
            self.ap.logger.info(f"Connected to Milvus at {self.uri}")
        except Exception as e:
            self.ap.logger.error(f"Failed to connect to Milvus: {e}")
            raise

    async def get_or_create_collection(self, collection: str, dimension: int | None = None) -> str:
        """Get or create a Milvus collection.

        When the collection does not exist yet it is created with the schema
        described on the class, an AUTOINDEX/COSINE index is built on the
        ``vector`` field and the collection is loaded so it can be searched
        right away.

        Args:
            collection: Collection name (corresponds to knowledge base UUID).
            dimension: Dimension of the vector field, used only when the
                collection has to be created. Defaults to 1536.

        Returns:
            The collection name.
        """
        if collection in self._collections:
            return self._collections[collection]

        # Check if collection exists
        has_collection = await asyncio.to_thread(
            self.client.has_collection, collection_name=collection
        )

        if not has_collection:
            # Create collection with custom schema to support string IDs
            from pymilvus import CollectionSchema, FieldSchema, DataType

            vector_dim = dimension if dimension else self._DEFAULT_DIMENSION

            fields = [
                FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
                FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
            ]

            schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")

            await asyncio.to_thread(
                self.client.create_collection,
                collection_name=collection,
                schema=schema,
                metric_type=self._METRIC_TYPE,
            )

            # Create index for vector field (required for loading/searching).
            # MilvusClient.create_index expects an IndexParams object, not a
            # plain dict, so build it through prepare_index_params().
            index_params = self.client.prepare_index_params()
            index_params.add_index(
                field_name="vector",
                index_type="AUTOINDEX",
                metric_type=self._METRIC_TYPE,
            )
            await asyncio.to_thread(
                self.client.create_index,
                collection_name=collection,
                index_params=index_params,
            )

            # Load the new collection so a search issued before any insert
            # does not hit an unloaded collection.
            await asyncio.to_thread(
                self.client.load_collection,
                collection_name=collection,
            )

            self.ap.logger.info(f"Created Milvus collection '{collection}' with index")
        else:
            self.ap.logger.info(f"Milvus collection '{collection}' already exists")

        self._collections[collection] = collection
        return collection

    async def add_embeddings(
        self,
        collection: str,
        ids: list[str],
        embeddings_list: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Add vector embeddings to a Milvus collection.

        Args:
            collection: Collection name.
            ids: List of unique IDs for each vector.
            embeddings_list: List of embedding vectors, parallel to ``ids``.
            metadatas: List of metadata dictionaries, parallel to ``ids``.
                The keys ``text``, ``file_id`` and ``uuid`` are stored;
                missing or ``None`` values are stored as empty strings
                because the schema fields are not nullable.

        Raises:
            ValueError: If ``ids`` and ``embeddings_list`` differ in length.
        """
        if not ids:
            # Nothing to insert; avoid sending an empty insert to Milvus.
            return

        if len(ids) != len(embeddings_list):
            raise ValueError(
                f"ids and embeddings_list length mismatch: {len(ids)} != {len(embeddings_list)}"
            )

        # Use the real embedding size if the collection has to be created.
        await self.get_or_create_collection(collection, dimension=len(embeddings_list[0]))

        # Prepare data in Milvus format
        data = []
        for i, (vector_id, vector) in enumerate(zip(ids, embeddings_list)):
            metadata = metadatas[i] if metadatas and i < len(metadatas) else {}
            entry = {"id": vector_id, "vector": vector}
            for meta_key, field_name in self._METADATA_FIELDS:
                value = metadata.get(meta_key)
                entry[field_name] = "" if value is None else value
            data.append(entry)

        # Insert data into Milvus
        await asyncio.to_thread(
            self.client.insert,
            collection_name=collection,
            data=data,
        )

        # Load collection for searching (Milvus requires this)
        await asyncio.to_thread(
            self.client.load_collection,
            collection_name=collection,
        )

        self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")

    async def search(
        self, collection: str, query_embedding: list[float], k: int = 5
    ) -> dict[str, Any]:
        """Search for similar vectors in a Milvus collection.

        Args:
            collection: Collection name.
            query_embedding: Query vector.
            k: Number of top results to return.

        Returns:
            Dictionary in Chroma-compatible format (nested lists) with the
            keys ``ids``, ``distances`` and ``metadatas``. ``distances``
            carries the ``distance`` value exactly as Milvus reports it.
            NOTE(review): with the COSINE metric Milvus reports a similarity
            score (larger = closer) - confirm consumers expect that.
        """
        await self.get_or_create_collection(collection, dimension=len(query_embedding))

        # Perform search
        search_params = {
            "metric_type": self._METRIC_TYPE,
            "params": {},
        }

        results = await asyncio.to_thread(
            self.client.search,
            collection_name=collection,
            data=[query_embedding],
            limit=k,
            search_params=search_params,
            output_fields=["text", "file_id", "chunk_uuid"],
        )

        # Convert results to Chroma-compatible format
        # Milvus returns: [[ {id, distance, entity: {...}} ]]
        ids = []
        distances = []
        metadatas = []

        if results and len(results) > 0:
            for hit in results[0]:
                ids.append(hit.get("id", ""))
                distances.append(hit.get("distance", 0.0))

                # Build metadata from entity fields, mapping the schema field
                # names back to the metadata keys callers supplied.
                entity = hit.get("entity", {})
                metadata = {}
                for meta_key, field_name in self._METADATA_FIELDS:
                    if field_name in entity:
                        metadata[meta_key] = entity[field_name]
                metadatas.append(metadata)

        # Return in Chroma-compatible format (nested lists)
        result = {
            "ids": [ids],
            "distances": [distances],
            "metadatas": [metadatas],
        }

        self.ap.logger.info(
            f"Milvus search in '{collection}' returned {len(ids)} results"
        )
        return result

    async def delete_by_file_id(self, collection: str, file_id: str) -> None:
        """Delete vectors from a collection by file_id.

        Args:
            collection: Collection name.
            file_id: File ID to filter deletion.
        """
        await self.get_or_create_collection(collection)

        # Escape backslashes and double quotes so the value cannot break out
        # of the quoted string literal in the filter expression.
        escaped_file_id = str(file_id).replace("\\", "\\\\").replace('"', '\\"')

        # Delete entities matching the file_id
        await asyncio.to_thread(
            self.client.delete,
            collection_name=collection,
            filter=f'file_id == "{escaped_file_id}"',
        )
        self.ap.logger.info(
            f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
        )

    async def delete_collection(self, collection: str):
        """Delete a Milvus collection.

        The collection is removed from the local cache first; a warning is
        logged when it does not exist on the server.

        Args:
            collection: Collection name to delete.
        """
        self._collections.pop(collection, None)

        # Check if collection exists before attempting deletion
        has_collection = await asyncio.to_thread(
            self.client.has_collection, collection_name=collection
        )

        if has_collection:
            await asyncio.to_thread(
                self.client.drop_collection, collection_name=collection
            )
            self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
        else:
            self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|