Files
LangBot/src/langbot/pkg/vector/vdbs/milvus.py
Junyan Qin (Chin) fc6e414be4 feat: add GitHub Actions workflow for linting with Ruff (#1929)
* feat: add GitHub Actions workflow for linting with Ruff

* refactor: rename lint job and add formatting step to Ruff workflow

* chore: run ruff format

* chore: rename Ruff lint job to 'Lint' and add frontend linting workflow
2026-01-23 13:43:12 +08:00

289 lines
11 KiB
Python

from __future__ import annotations
import asyncio
from typing import Any, Dict
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
from pymilvus.milvus_client.index import IndexParams
from langbot.pkg.vector.vdb import VectorDatabase
from langbot.pkg.core import app
class MilvusVectorDatabase(VectorDatabase):
"""Milvus vector database implementation"""
def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
"""Initialize Milvus vector database
Args:
ap: Application instance
uri: Milvus connection URI. For local file: "milvus.db"
For remote server: "http://localhost:19530"
token: Optional authentication token for remote connections
"""
self.ap = ap
self.uri = uri
self.token = token
self.db_name = db_name
self.client = None
self._collections: set[str] = set()
self._initialize_client()
def _initialize_client(self):
"""Initialize Milvus client connection"""
try:
if self.token:
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
else:
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
self.ap.logger.info(f'Connected to Milvus at {self.uri}')
except Exception as e:
self.ap.logger.error(f'Failed to connect to Milvus: {e}')
raise
@staticmethod
def _normalize_collection_name(collection: str) -> str:
"""Normalize collection name to comply with Milvus naming requirements.
Milvus requirements:
- First character must be an underscore or letter
- Can only contain numbers, letters and underscores
Args:
collection: Original collection name (e.g., UUID with hyphens)
Returns:
Normalized collection name that complies with Milvus requirements
"""
# Replace hyphens with underscores
normalized = collection.replace('-', '_')
# If first character is not a letter or underscore, prepend 'kb_'
if normalized and not (normalized[0].isalpha() or normalized[0] == '_'):
normalized = 'kb_' + normalized
return normalized
async def _ensure_vector_index(self, collection: str) -> None:
"""Ensure the vector field has an index.
Args:
collection: Normalized collection name
"""
index_params = IndexParams()
index_params.add_index(
field_name='vector',
index_type='AUTOINDEX',
metric_type='COSINE',
)
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
"""Internal method to get or create a Milvus collection with proper configuration.
Args:
collection: Collection name (corresponds to knowledge base UUID)
vector_size: Dimension of the vectors (if None, defaults to 1536)
"""
# Normalize collection name for Milvus compatibility
collection = self._normalize_collection_name(collection)
if collection in self._collections:
return collection
# Check if collection exists
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
if not has_collection:
# Default dimension if not specified (for backward compatibility)
if vector_size is None:
vector_size = 1536
fields = [
FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
]
schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
await asyncio.to_thread(
self.client.create_collection,
collection_name=collection,
schema=schema,
metric_type='COSINE',
)
await self._ensure_vector_index(collection)
self.ap.logger.info(
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
)
else:
# Ensure index exists for existing collection
await self._ensure_index_if_missing(collection)
self.ap.logger.info(f"Milvus collection '{collection}' already exists")
self._collections.add(collection)
return collection
async def _ensure_index_if_missing(self, collection: str) -> None:
"""Check if index exists for collection and create if missing.
Args:
collection: Normalized collection name
"""
try:
indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
if 'vector' not in indexes:
await self._ensure_vector_index(collection)
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
except Exception as e:
self.ap.logger.warning(f"Could not verify/create index for collection '{collection}': {e}")
async def get_or_create_collection(self, collection: str):
"""Get or create a Milvus collection (without vector size - will use default).
Args:
collection: Collection name (corresponds to knowledge base UUID)
"""
collection = self._normalize_collection_name(collection)
return await self._get_or_create_collection_internal(collection)
async def add_embeddings(
self,
collection: str,
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
) -> None:
"""Add vector embeddings to Milvus collection
Args:
collection: Collection name
ids: List of unique IDs for each vector
embeddings_list: List of embedding vectors
metadatas: List of metadata dictionaries for each vector
"""
collection = self._normalize_collection_name(collection)
if not embeddings_list:
return
# Ensure collection exists with correct dimension
vector_size = len(embeddings_list[0])
await self._get_or_create_collection_internal(collection, vector_size)
# Prepare data in Milvus format
data = []
for i, vector_id in enumerate(ids):
entry = {
'id': vector_id,
'vector': embeddings_list[i],
}
# Add metadata fields
if metadatas and i < len(metadatas):
metadata = metadatas[i]
# Add common metadata fields
if 'text' in metadata:
entry['text'] = metadata['text']
if 'file_id' in metadata:
entry['file_id'] = metadata['file_id']
if 'uuid' in metadata:
entry['chunk_uuid'] = metadata['uuid']
data.append(entry)
# Insert data into Milvus
await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
# Load collection for searching (Milvus requires this)
await asyncio.to_thread(self.client.load_collection, collection_name=collection)
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
"""Search for similar vectors in Milvus collection
Args:
collection: Collection name
query_embedding: Query vector
k: Number of top results to return
Returns:
Dictionary with search results in Chroma-compatible format
"""
collection = self._normalize_collection_name(collection)
await self.get_or_create_collection(collection)
# Perform search
search_params = {'metric_type': 'COSINE', 'params': {}}
results = await asyncio.to_thread(
self.client.search,
collection_name=collection,
data=[query_embedding],
limit=k,
search_params=search_params,
output_fields=['text', 'file_id', 'chunk_uuid'],
)
# Convert results to Chroma-compatible format
# Milvus returns: [[ {id, distance, entity: {...}} ]]
ids = []
distances = []
metadatas = []
if results and len(results) > 0:
for hit in results[0]:
ids.append(hit.get('id', ''))
distances.append(hit.get('distance', 0.0))
# Build metadata from entity fields
entity = hit.get('entity', {})
metadata = {}
if 'text' in entity:
metadata['text'] = entity['text']
if 'file_id' in entity:
metadata['file_id'] = entity['file_id']
if 'chunk_uuid' in entity:
metadata['uuid'] = entity['chunk_uuid']
metadatas.append(metadata)
# Return in Chroma-compatible format (nested lists)
result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
return result
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
"""Delete vectors from collection by file_id
Args:
collection: Collection name
file_id: File ID to filter deletion
"""
collection = self._normalize_collection_name(collection)
await self.get_or_create_collection(collection)
# Delete entities matching the file_id
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
async def delete_collection(self, collection: str):
"""Delete a Milvus collection
Args:
collection: Collection name to delete
"""
collection = self._normalize_collection_name(collection)
self._collections.discard(collection)
# Check if collection exists before attempting deletion
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
if has_collection:
await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
else:
self.ap.logger.warning(f"Milvus collection '{collection}' not found")