diff --git a/src/langbot/pkg/rag/knowledge/services/embedder.py b/src/langbot/pkg/rag/knowledge/services/embedder.py index a067c90c..f93382ff 100644 --- a/src/langbot/pkg/rag/knowledge/services/embedder.py +++ b/src/langbot/pkg/rag/knowledge/services/embedder.py @@ -32,12 +32,18 @@ class Embedder(BaseService): await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts)) - # get embeddings - embeddings_list: list[list[float]] = await embedding_model.provider.requester.invoke_embedding( - model=embedding_model, - input_text=chunks, - extra_args={}, # TODO: add extra args - ) + # get embeddings (batch size limit: 64 for OpenAI) + MAX_BATCH_SIZE = 64 + embeddings_list: list[list[float]] = [] + + for i in range(0, len(chunks), MAX_BATCH_SIZE): + batch = chunks[i:i + MAX_BATCH_SIZE] + batch_embeddings = await embedding_model.provider.requester.invoke_embedding( + model=embedding_model, + input_text=batch, + extra_args={}, # TODO: add extra args + ) + embeddings_list.extend(batch_embeddings) # save embeddings to vdb await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts) diff --git a/src/langbot/pkg/vector/mgr.py b/src/langbot/pkg/vector/mgr.py index 52b8b479..f95f5f75 100644 --- a/src/langbot/pkg/vector/mgr.py +++ b/src/langbot/pkg/vector/mgr.py @@ -37,7 +37,8 @@ class VectorDBManager: milvus_config = kb_config.get('milvus', {}) uri = milvus_config.get('uri', './data/milvus.db') token = milvus_config.get('token') - self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token) + db_name = milvus_config.get('db_name', 'default') + self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token, db_name=db_name) self.ap.logger.info('Initialized Milvus vector database backend.') elif vdb_type == 'pgvector': diff --git a/src/langbot/pkg/vector/vdbs/milvus.py b/src/langbot/pkg/vector/vdbs/milvus.py index d9f822cd..f15071c4 100644 --- a/src/langbot/pkg/vector/vdbs/milvus.py +++ b/src/langbot/pkg/vector/vdbs/milvus.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio from typing import Any, Dict -from pymilvus import MilvusClient, DataType +from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema +from pymilvus.milvus_client.index import IndexParams from langbot.pkg.vector.vdb import VectorDatabase from langbot.pkg.core import app @@ -9,7 +10,7 @@ from langbot.pkg.core import app class MilvusVectorDatabase(VectorDatabase): """Milvus vector database implementation""" - def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None): + def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None): """Initialize Milvus vector database Args: @@ -21,30 +22,76 @@ class MilvusVectorDatabase(VectorDatabase): self.ap = ap self.uri = uri self.token = token + self.db_name = db_name self.client = None - self._collections = {} + self._collections: set[str] = set() self._initialize_client() def _initialize_client(self): """Initialize Milvus client connection""" try: if self.token: - self.client = MilvusClient(uri=self.uri, token=self.token) + self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name) else: - self.client = MilvusClient(uri=self.uri) + self.client = MilvusClient(uri=self.uri, db_name=self.db_name) self.ap.logger.info(f"Connected to Milvus at {self.uri}") except Exception as e: self.ap.logger.error(f"Failed to connect to Milvus: {e}") raise - async def get_or_create_collection(self, collection: str): - """Get or create a Milvus collection + @staticmethod + def _normalize_collection_name(collection: str) -> str: + """Normalize collection name to comply with Milvus naming requirements. + + Milvus requirements: + - First character must be an underscore or letter + - Can only contain numbers, letters and underscores + + Args: + collection: Original collection name (e.g., UUID with hyphens) + + Returns: + Normalized collection name that complies with Milvus requirements + """ + # Replace hyphens with underscores + normalized = collection.replace('-', '_') + + # If first character is not a letter or underscore, prepend 'kb_' + if normalized and not (normalized[0].isalpha() or normalized[0] == '_'): + normalized = 'kb_' + normalized + + return normalized + + async def _ensure_vector_index(self, collection: str) -> None: + """Ensure the vector field has an index. + + Args: + collection: Normalized collection name + """ + index_params = IndexParams() + index_params.add_index( + field_name="vector", + index_type="AUTOINDEX", + metric_type="COSINE", + ) + await asyncio.to_thread( + self.client.create_index, + collection_name=collection, + index_params=index_params + ) + + async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None): + """Internal method to get or create a Milvus collection with proper configuration. Args: collection: Collection name (corresponds to knowledge base UUID) + vector_size: Dimension of the vectors (if None, defaults to 1536) """ + # Normalize collection name for Milvus compatibility + collection = self._normalize_collection_name(collection) + if collection in self._collections: - return self._collections[collection] + return collection # Check if collection exists has_collection = await asyncio.to_thread( @@ -52,12 +99,13 @@ class MilvusVectorDatabase(VectorDatabase): ) if not has_collection: - # Create collection with custom schema to support string IDs - from pymilvus import CollectionSchema, FieldSchema, DataType + # Default dimension if not specified (for backward compatibility) + if vector_size is None: + vector_size = 1536 fields = [ FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), @@ -72,26 +120,42 @@ class MilvusVectorDatabase(VectorDatabase): metric_type="COSINE", ) - # Create index for vector field (required for loading/searching) - index_params = { - "metric_type": "COSINE", - "index_type": "AUTOINDEX", - "params": {} - } - await asyncio.to_thread( - self.client.create_index, - collection_name=collection, - field_name="vector", - index_params=index_params - ) - - self.ap.logger.info(f"Created Milvus collection '{collection}' with index") + await self._ensure_vector_index(collection) + self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX") else: + # Ensure index exists for existing collection + await self._ensure_index_if_missing(collection) self.ap.logger.info(f"Milvus collection '{collection}' already exists") - self._collections[collection] = collection + self._collections.add(collection) return collection + async def _ensure_index_if_missing(self, collection: str) -> None: + """Check if index exists for collection and create if missing. + + Args: + collection: Normalized collection name + """ + try: + indexes = await asyncio.to_thread( + self.client.list_indexes, + collection_name=collection + ) + if "vector" not in indexes: + await self._ensure_vector_index(collection) + self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'") + except Exception as e: + self.ap.logger.warning(f"Could not verify/create index for collection '{collection}': {e}") + + async def get_or_create_collection(self, collection: str): + """Get or create a Milvus collection (without vector size - will use default). + + Args: + collection: Collection name (corresponds to knowledge base UUID) + """ + collection = self._normalize_collection_name(collection) + return await self._get_or_create_collection_internal(collection) + async def add_embeddings( self, collection: str, @@ -107,7 +171,14 @@ class MilvusVectorDatabase(VectorDatabase): embeddings_list: List of embedding vectors metadatas: List of metadata dictionaries for each vector """ - await self.get_or_create_collection(collection) + collection = self._normalize_collection_name(collection) + + if not embeddings_list: + return + + # Ensure collection exists with correct dimension + vector_size = len(embeddings_list[0]) + await self._get_or_create_collection_internal(collection, vector_size) # Prepare data in Milvus format data = [] @@ -156,6 +227,7 @@ class MilvusVectorDatabase(VectorDatabase): Returns: Dictionary with search results in Chroma-compatible format """ + collection = self._normalize_collection_name(collection) await self.get_or_create_collection(collection) # Perform search @@ -214,6 +286,7 @@ class MilvusVectorDatabase(VectorDatabase): collection: Collection name file_id: File ID to filter deletion """ + collection = self._normalize_collection_name(collection) await self.get_or_create_collection(collection) # Delete entities matching the file_id @@ -232,8 +305,9 @@ class MilvusVectorDatabase(VectorDatabase): Args: collection: Collection name to delete """ - if collection in self._collections: - del self._collections[collection] + collection = self._normalize_collection_name(collection) + + self._collections.discard(collection) # Check if collection exists before attempting deletion has_collection = await asyncio.to_thread( diff --git a/src/langbot/templates/config.yaml b/src/langbot/templates/config.yaml index 590102ed..bd4bd180 100644 --- a/src/langbot/templates/config.yaml +++ b/src/langbot/templates/config.yaml @@ -51,6 +51,7 @@ vdb: milvus: uri: 'http://127.0.0.1:19530' token: '' + db_name: '' pgvector: host: '127.0.0.1' port: 5433