mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
feat(milvus): milvus related updates (#1908)
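- Add Milvus db_name configuration and pass it to the Milvus client. - Normalize knowledge base UUIDs into valid Milvus collection names. - Add MAX_BATCH_SIZE (64) batching for OpenAI embedding requests. - Support vector sizes other than 1536 by taking the dimension from the embeddings.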
This commit is contained in:
@@ -32,12 +32,18 @@ class Embedder(BaseService):
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
|
||||
|
||||
# get embeddings
|
||||
embeddings_list: list[list[float]] = await embedding_model.provider.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=chunks,
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
# get embeddings (batch size limit: 64 for OpenAI)
|
||||
MAX_BATCH_SIZE = 64
|
||||
embeddings_list: list[list[float]] = []
|
||||
|
||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
||||
batch = chunks[i:i + MAX_BATCH_SIZE]
|
||||
batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=batch,
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
embeddings_list.extend(batch_embeddings)
|
||||
|
||||
# save embeddings to vdb
|
||||
await self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts)
|
||||
|
||||
@@ -37,7 +37,8 @@ class VectorDBManager:
|
||||
milvus_config = kb_config.get('milvus', {})
|
||||
uri = milvus_config.get('uri', './data/milvus.db')
|
||||
token = milvus_config.get('token')
|
||||
self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token)
|
||||
db_name = milvus_config.get('db_name', 'default')
|
||||
self.vector_db = MilvusVectorDatabase(self.ap, uri=uri, token=token, db_name=db_name)
|
||||
self.ap.logger.info('Initialized Milvus vector database backend.')
|
||||
|
||||
elif vdb_type == 'pgvector':
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pymilvus import MilvusClient, DataType
|
||||
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
|
||||
from pymilvus.milvus_client.index import IndexParams
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
|
||||
@@ -9,7 +10,7 @@ from langbot.pkg.core import app
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None):
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None):
|
||||
"""Initialize Milvus vector database
|
||||
|
||||
Args:
|
||||
@@ -21,30 +22,76 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self.ap = ap
|
||||
self.uri = uri
|
||||
self.token = token
|
||||
self.db_name = db_name
|
||||
self.client = None
|
||||
self._collections = {}
|
||||
self._collections: set[str] = set()
|
||||
self._initialize_client()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize Milvus client connection"""
|
||||
try:
|
||||
if self.token:
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token)
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
||||
else:
|
||||
self.client = MilvusClient(uri=self.uri)
|
||||
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
"""Get or create a Milvus collection
|
||||
@staticmethod
|
||||
def _normalize_collection_name(collection: str) -> str:
|
||||
"""Normalize collection name to comply with Milvus naming requirements.
|
||||
|
||||
Milvus requirements:
|
||||
- First character must be an underscore or letter
|
||||
- Can only contain numbers, letters and underscores
|
||||
|
||||
Args:
|
||||
collection: Original collection name (e.g., UUID with hyphens)
|
||||
|
||||
Returns:
|
||||
Normalized collection name that complies with Milvus requirements
|
||||
"""
|
||||
# Replace hyphens with underscores
|
||||
normalized = collection.replace('-', '_')
|
||||
|
||||
# If first character is not a letter or underscore, prepend 'kb_'
|
||||
if normalized and not (normalized[0].isalpha() or normalized[0] == '_'):
|
||||
normalized = 'kb_' + normalized
|
||||
|
||||
return normalized
|
||||
|
||||
async def _ensure_vector_index(self, collection: str) -> None:
    """Create the index on the ``vector`` field of a collection.

    Builds an ``IndexParams`` entry (AUTOINDEX, COSINE metric) for the
    ``vector`` field and submits it through the client's ``create_index``.
    The client call is blocking, so it runs in a worker thread.

    Args:
        collection: Normalized collection name
    """
    index_spec = {
        "field_name": "vector",
        "index_type": "AUTOINDEX",
        "metric_type": "COSINE",
    }
    vector_index = IndexParams()
    vector_index.add_index(**index_spec)

    create_index = self.client.create_index
    await asyncio.to_thread(
        create_index,
        collection_name=collection,
        index_params=vector_index,
    )
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
||||
"""Internal method to get or create a Milvus collection with proper configuration.
|
||||
|
||||
Args:
|
||||
collection: Collection name (corresponds to knowledge base UUID)
|
||||
vector_size: Dimension of the vectors (if None, defaults to 1536)
|
||||
"""
|
||||
# Normalize collection name for Milvus compatibility
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
if collection in self._collections:
|
||||
return self._collections[collection]
|
||||
return collection
|
||||
|
||||
# Check if collection exists
|
||||
has_collection = await asyncio.to_thread(
|
||||
@@ -52,12 +99,13 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
)
|
||||
|
||||
if not has_collection:
|
||||
# Create collection with custom schema to support string IDs
|
||||
from pymilvus import CollectionSchema, FieldSchema, DataType
|
||||
# Default dimension if not specified (for backward compatibility)
|
||||
if vector_size is None:
|
||||
vector_size = 1536
|
||||
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
||||
@@ -72,26 +120,42 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
# Create index for vector field (required for loading/searching)
|
||||
index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {}
|
||||
}
|
||||
await asyncio.to_thread(
|
||||
self.client.create_index,
|
||||
collection_name=collection,
|
||||
field_name="vector",
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with index")
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX")
|
||||
else:
|
||||
# Ensure index exists for existing collection
|
||||
await self._ensure_index_if_missing(collection)
|
||||
self.ap.logger.info(f"Milvus collection '{collection}' already exists")
|
||||
|
||||
self._collections[collection] = collection
|
||||
self._collections.add(collection)
|
||||
return collection
|
||||
|
||||
async def _ensure_index_if_missing(self, collection: str) -> None:
    """Create the vector index for an existing collection when it is absent.

    Best-effort: any exception raised while listing or creating the index
    is logged as a warning and not propagated.

    Args:
        collection: Normalized collection name
    """
    try:
        existing = await asyncio.to_thread(
            self.client.list_indexes,
            collection_name=collection,
        )
        # NOTE(review): this compares against whatever list_indexes returns,
        # presumably index names; confirm the vector index is actually
        # reported as "vector".
        if "vector" in existing:
            return

        await self._ensure_vector_index(collection)
        self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
    except Exception as exc:
        self.ap.logger.warning(f"Could not verify/create index for collection '{collection}': {exc}")
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
    """Get or create a Milvus collection without specifying a vector size.

    The name is normalized and the work is delegated to
    ``_get_or_create_collection_internal``, which applies its default
    dimension when it has to create the collection.

    Args:
        collection: Collection name (corresponds to knowledge base UUID)

    Returns:
        Whatever ``_get_or_create_collection_internal`` returns for the
        normalized name.
    """
    normalized_name = self._normalize_collection_name(collection)
    result = await self._get_or_create_collection_internal(normalized_name)
    return result
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
@@ -107,7 +171,14 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
embeddings_list: List of embedding vectors
|
||||
metadatas: List of metadata dictionaries for each vector
|
||||
"""
|
||||
await self.get_or_create_collection(collection)
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
if not embeddings_list:
|
||||
return
|
||||
|
||||
# Ensure collection exists with correct dimension
|
||||
vector_size = len(embeddings_list[0])
|
||||
await self._get_or_create_collection_internal(collection, vector_size)
|
||||
|
||||
# Prepare data in Milvus format
|
||||
data = []
|
||||
@@ -156,6 +227,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
Returns:
|
||||
Dictionary with search results in Chroma-compatible format
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Perform search
|
||||
@@ -214,6 +286,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
collection: Collection name
|
||||
file_id: File ID to filter deletion
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Delete entities matching the file_id
|
||||
@@ -232,8 +305,9 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
Args:
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
if collection in self._collections:
|
||||
del self._collections[collection]
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
self._collections.discard(collection)
|
||||
|
||||
# Check if collection exists before attempting deletion
|
||||
has_collection = await asyncio.to_thread(
|
||||
|
||||
@@ -51,6 +51,7 @@ vdb:
|
||||
milvus:
|
||||
uri: 'http://127.0.0.1:19530'
|
||||
token: ''
|
||||
db_name: ''
|
||||
pgvector:
|
||||
host: '127.0.0.1'
|
||||
port: 5433
|
||||
|
||||
Reference in New Issue
Block a user