from __future__ import annotations import asyncio from typing import Any, Dict from pymilvus import MilvusClient, DataType from langbot.pkg.vector.vdb import VectorDatabase from langbot.pkg.core import app class MilvusVectorDatabase(VectorDatabase): """Milvus vector database implementation""" def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None): """Initialize Milvus vector database Args: ap: Application instance uri: Milvus connection URI. For local file: "milvus.db" For remote server: "http://localhost:19530" token: Optional authentication token for remote connections """ self.ap = ap self.uri = uri self.token = token self.client = None self._collections = {} self._initialize_client() def _initialize_client(self): """Initialize Milvus client connection""" try: if self.token: self.client = MilvusClient(uri=self.uri, token=self.token) else: self.client = MilvusClient(uri=self.uri) self.ap.logger.info(f"Connected to Milvus at {self.uri}") except Exception as e: self.ap.logger.error(f"Failed to connect to Milvus: {e}") raise async def get_or_create_collection(self, collection: str): """Get or create a Milvus collection Args: collection: Collection name (corresponds to knowledge base UUID) """ if collection in self._collections: return self._collections[collection] # Check if collection exists has_collection = await asyncio.to_thread( self.client.has_collection, collection_name=collection ) if not has_collection: # Create collection with custom schema to support string IDs from pymilvus import CollectionSchema, FieldSchema, DataType fields = [ FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), ] schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors") await asyncio.to_thread( self.client.create_collection, collection_name=collection, schema=schema, metric_type="COSINE", ) # Create index for vector field (required for loading/searching) index_params = { "metric_type": "COSINE", "index_type": "AUTOINDEX", "params": {} } await asyncio.to_thread( self.client.create_index, collection_name=collection, field_name="vector", index_params=index_params ) self.ap.logger.info(f"Created Milvus collection '{collection}' with index") else: self.ap.logger.info(f"Milvus collection '{collection}' already exists") self._collections[collection] = collection return collection async def add_embeddings( self, collection: str, ids: list[str], embeddings_list: list[list[float]], metadatas: list[dict[str, Any]], ) -> None: """Add vector embeddings to Milvus collection Args: collection: Collection name ids: List of unique IDs for each vector embeddings_list: List of embedding vectors metadatas: List of metadata dictionaries for each vector """ await self.get_or_create_collection(collection) # Prepare data in Milvus format data = [] for i, vector_id in enumerate(ids): entry = { "id": vector_id, "vector": embeddings_list[i], } # Add metadata fields if metadatas and i < len(metadatas): metadata = metadatas[i] # Add common metadata fields if "text" in metadata: entry["text"] = metadata["text"] if "file_id" in metadata: entry["file_id"] = metadata["file_id"] if "uuid" in metadata: entry["chunk_uuid"] = metadata["uuid"] data.append(entry) # Insert data into Milvus await asyncio.to_thread( self.client.insert, collection_name=collection, data=data ) # Load collection for searching (Milvus requires this) await asyncio.to_thread( self.client.load_collection, collection_name=collection ) self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'") async def search( self, collection: str, query_embedding: list[float], k: int = 5 ) -> Dict[str, Any]: """Search for similar vectors in Milvus collection Args: collection: Collection name query_embedding: Query vector k: Number of top results to return Returns: Dictionary with search results in Chroma-compatible format """ await self.get_or_create_collection(collection) # Perform search search_params = { "metric_type": "COSINE", "params": {} } results = await asyncio.to_thread( self.client.search, collection_name=collection, data=[query_embedding], limit=k, search_params=search_params, output_fields=["text", "file_id", "chunk_uuid"] ) # Convert results to Chroma-compatible format # Milvus returns: [[ {id, distance, entity: {...}} ]] ids = [] distances = [] metadatas = [] if results and len(results) > 0: for hit in results[0]: ids.append(hit.get("id", "")) distances.append(hit.get("distance", 0.0)) # Build metadata from entity fields entity = hit.get("entity", {}) metadata = {} if "text" in entity: metadata["text"] = entity["text"] if "file_id" in entity: metadata["file_id"] = entity["file_id"] if "chunk_uuid" in entity: metadata["uuid"] = entity["chunk_uuid"] metadatas.append(metadata) # Return in Chroma-compatible format (nested lists) result = { "ids": [ids], "distances": [distances], "metadatas": [metadatas] } self.ap.logger.info( f"Milvus search in '{collection}' returned {len(ids)} results" ) return result async def delete_by_file_id(self, collection: str, file_id: str) -> None: """Delete vectors from collection by file_id Args: collection: Collection name file_id: File ID to filter deletion """ await self.get_or_create_collection(collection) # Delete entities matching the file_id await asyncio.to_thread( self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"' ) self.ap.logger.info( f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}" ) async def delete_collection(self, collection: str): """Delete a Milvus collection Args: collection: Collection name to delete """ if collection in self._collections: del self._collections[collection] # Check if collection exists before attempting deletion has_collection = await asyncio.to_thread( self.client.has_collection, collection_name=collection ) if has_collection: await asyncio.to_thread( self.client.drop_collection, collection_name=collection ) self.ap.logger.info(f"Deleted Milvus collection '{collection}'") else: self.ap.logger.warning(f"Milvus collection '{collection}' not found")