feat(rag): make embedding and retrieving available

This commit is contained in:
Junyan Qin
2025-07-16 21:17:18 +08:00
parent f731115805
commit 2f2db4d445
20 changed files with 180 additions and 368 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import abc
from typing import Any, List, Dict
from typing import Any, Dict
import numpy as np
@@ -9,10 +9,10 @@ class VectorDatabase(abc.ABC):
def add_embeddings(
self,
collection: str,
ids: List[str],
embeddings: np.ndarray,
metadatas: List[Dict[str, Any]],
documents: List[str],
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
documents: list[str],
) -> None:
"""向指定 collection 添加向量数据。"""
pass

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import numpy as np
from typing import Any, List, Dict
import chromadb
from typing import Any
from chromadb import PersistentClient
from pkg.vector.vdb import VectorDatabase
from pkg.core import app
@@ -12,7 +12,7 @@ class ChromaVectorDatabase(VectorDatabase):
self.client = PersistentClient(path=base_path)
self._collections = {}
def get_or_create_collection(self, collection: str):
def get_or_create_collection(self, collection: str) -> chromadb.Collection:
if collection not in self._collections:
self._collections[collection] = self.client.get_or_create_collection(name=collection)
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
@@ -21,26 +21,25 @@ class ChromaVectorDatabase(VectorDatabase):
def add_embeddings(
self,
collection: str,
ids: List[str],
embeddings: np.ndarray,
metadatas: List[Dict[str, Any]],
documents: List[str],
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
) -> None:
col = self.get_or_create_collection(collection)
col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents)
col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
col = self.get_or_create_collection(collection)
results = col.query(
query_embeddings=query_embedding.tolist(),
query_embeddings=query_embedding,
n_results=k,
include=['metadatas', 'distances', 'documents'],
)
self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
return results
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None:
def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
col = self.get_or_create_collection(collection)
col.delete(where=where)
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")