mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 14:56:03 +00:00
feat(rag): all APIs ok
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
class VectorDatabase(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def add_embeddings(
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: list[str],
|
||||
@@ -18,16 +18,20 @@ class VectorDatabase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
    """Retrieve the vectors most similar to the query in the given collection.

    Args:
        collection: Name of the collection to search.
        query_embedding: Embedding to compare against the stored vectors.
        k: Number of nearest matches requested.

    Returns:
        A backend-defined mapping describing the matches.
    """
    pass
|
||||
|
||||
@abc.abstractmethod
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
    """Delete the vectors associated with ``file_id`` from the given collection.

    Args:
        collection: Name of the collection to delete from.
        file_id: Identifier of the file whose vectors should be removed.
    """
    pass
|
||||
|
||||
@abc.abstractmethod
async def get_or_create_collection(self, collection: str):
    """Get the named collection, creating it if it does not exist.

    Args:
        collection: Name of the collection.
    """
    pass
|
||||
|
||||
@abc.abstractmethod
async def delete_collection(self, collection: str):
    """Delete the named collection.

    Args:
        collection: Name of the collection to delete.
    """
    pass
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import chromadb
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from chromadb import PersistentClient
|
||||
from pkg.vector.vdb import VectorDatabase
|
||||
from pkg.core import app
|
||||
import chromadb
|
||||
|
||||
|
||||
class ChromaVectorDatabase(VectorDatabase):
|
||||
@@ -12,26 +13,29 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.client = PersistentClient(path=base_path)
|
||||
self._collections = {}
|
||||
|
||||
async def get_or_create_collection(self, collection: str) -> chromadb.Collection:
    """Return the Chroma collection handle for ``collection``, creating it if needed.

    Handles are cached in ``self._collections``; the client is only asked
    for a collection the first time its name is seen.

    Args:
        collection: Name of the Chroma collection.

    Returns:
        The cached or newly fetched collection handle.
    """
    if collection not in self._collections:
        # The Chroma client call is synchronous; run it in a worker thread
        # so the event loop is not blocked while it executes.
        self._collections[collection] = await asyncio.to_thread(
            self.client.get_or_create_collection, name=collection
        )
        self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
    return self._collections[collection]
|
||||
|
||||
async def add_embeddings(
    self,
    collection: str,
    ids: list[str],
    embeddings_list: list[list[float]],
    metadatas: list[dict[str, Any]],
) -> None:
    """Add embeddings with their ids and metadata to a Chroma collection.

    Args:
        collection: Name of the target collection.
        ids: Identifiers passed to Chroma as ``ids``.
        embeddings_list: Vectors passed to Chroma as ``embeddings``.
        metadatas: Metadata dicts passed to Chroma as ``metadatas``.
    """
    col = await self.get_or_create_collection(collection)
    # ``col.add`` is a synchronous call; run it off the event loop thread.
    await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas)
    self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||
|
||||
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = self.get_or_create_collection(collection)
|
||||
results = col.query(
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = await self.get_or_create_collection(collection)
|
||||
results = await asyncio.to_thread(
|
||||
col.query,
|
||||
query_embeddings=query_embedding,
|
||||
n_results=k,
|
||||
include=['metadatas', 'distances', 'documents'],
|
||||
@@ -39,7 +43,13 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
|
||||
return results
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
    """Delete embeddings whose ``file_id`` metadata matches from a Chroma collection.

    Args:
        collection: Name of the collection to delete from.
        file_id: Value matched against the ``file_id`` metadata field.
    """
    col = await self.get_or_create_collection(collection)
    # ``col.delete`` is a synchronous call; run it off the event loop thread.
    await asyncio.to_thread(col.delete, where={'file_id': file_id})
    self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_collection(self, collection: str):
    """Delete a Chroma collection and drop its cached handle.

    Args:
        collection: Name of the collection to delete.
    """
    # Drop the cached handle first so a later lookup fetches a fresh one;
    # a name that was never cached is simply ignored.
    self._collections.pop(collection, None)
    # ``delete_collection`` on the client is synchronous; run it in a worker thread.
    await asyncio.to_thread(self.client.delete_collection, name=collection)
    self.ap.logger.info(f"Chroma collection '{collection}' deleted.")
|
||||
|
||||
Reference in New Issue
Block a user