From 199164fc4b87b645fe0b2479d5eb5ea2949dd557 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sun, 13 Jul 2025 23:12:08 +0800 Subject: [PATCH] fix: delete embedding models file --- .../knowledge/services/embedding_models.py | 238 ------------------ 1 file changed, 238 deletions(-) delete mode 100644 pkg/rag/knowledge/services/embedding_models.py diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py deleted file mode 100644 index 7301d640..00000000 --- a/pkg/rag/knowledge/services/embedding_models.py +++ /dev/null @@ -1,238 +0,0 @@ -# services/embedding_models.py - -import os -from typing import Dict, Any, List -import logging -import aiohttp # Import aiohttp for asynchronous requests -import asyncio -from sentence_transformers import SentenceTransformer - -logger = logging.getLogger(__name__) - - -# Base class for all embedding models -class BaseEmbeddingModel: - def __init__(self, model_name: str): - self.model_name = model_name - self._embedding_dimension = None - - async def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Asynchronously embeds a list of texts.""" - raise NotImplementedError - - async def embed_query(self, text: str) -> List[float]: - """Asynchronously embeds a single query text.""" - raise NotImplementedError - - @property - def embedding_dimension(self) -> int: - """Returns the embedding dimension of the model.""" - if self._embedding_dimension is None: - raise NotImplementedError('Embedding dimension not set for this model.') - return self._embedding_dimension - - -class EmbeddingModelFactory: - @staticmethod - def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: - """ - Factory method to create an embedding model instance. - Currently only supports 'third_party_api' types. - """ - if model_name_key not in EMBEDDING_MODEL_CONFIGS: - raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") - - config = EMBEDDING_MODEL_CONFIGS[model_name_key] - - if config['type'] == 'third_party_api': - required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] - if not all(key in config for key in required_keys): - raise ValueError( - f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}" - ) - - # Retrieve model_name from config if it differs from model_name_key - # Some APIs expect a specific 'model' value in the payload that might be different from the key - api_model_name = config.get('model_name', model_name_key) - - return ThirdPartyAPIEmbeddingModel( - model_name=api_model_name, # Use the model_name from config or the key - api_endpoint=config['api_endpoint'], - headers=config['headers'], - payload_template=config['payload_template'], - embedding_dimension=config['embedding_dimension'], - ) - - -class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): - def __init__(self, model_name: str): - super().__init__(model_name) - try: - # SentenceTransformer is inherently synchronous, but we'll wrap its calls - # in async methods. The actual computation will still block the event loop - # if not run in a separate thread/process, but this keeps the API consistent. - self.model = SentenceTransformer(model_name) - self._embedding_dimension = self.model.get_sentence_embedding_dimension() - logger.info( - f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}" - ) - except Exception as e: - logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}') - raise - - async def embed_documents(self, texts: List[str]) -> List[List[float]]: - # For CPU-bound tasks like local model inference, consider running in a thread pool - # to prevent blocking the event loop for long operations. - # For simplicity here, we'll call it directly. - return self.model.encode(texts).tolist() - - async def embed_query(self, text: str) -> List[float]: - return self.model.encode(text).tolist() - - -class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): - def __init__( - self, - model_name: str, - api_endpoint: str, - headers: Dict[str, str], - payload_template: Dict[str, Any], - embedding_dimension: int, - ): - super().__init__(model_name) - self.api_endpoint = api_endpoint - self.headers = headers - self.payload_template = payload_template - self._embedding_dimension = embedding_dimension - self.session = None # aiohttp client session will be initialized on first use or in a context manager - logger.info( - f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}" - ) - - async def _get_session(self): - """Lazily create or return the aiohttp client session.""" - if self.session is None or self.session.closed: - self.session = aiohttp.ClientSession() - return self.session - - async def close_session(self): - """Explicitly close the aiohttp client session.""" - if self.session and not self.session.closed: - await self.session.close() - self.session = None - logger.info(f'Closed aiohttp session for model {self.model_name}') - - async def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Asynchronously embeds a list of texts using the third-party API.""" - session = await self._get_session() - embeddings = [] - tasks = [] - for text in texts: - payload = self.payload_template.copy() - if 'input' in payload: - payload['input'] = text - elif 'texts' in payload: - payload['texts'] = [text] - else: - raise ValueError('Payload template does not contain expected text input key.') - - tasks.append(self._make_api_request(session, payload)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, res in enumerate(results): - if isinstance(res, Exception): - logger.error(f"Error embedding text '{texts[i][:50]}...': {res}") - # Depending on your error handling strategy, you might: - # - Append None or an empty list - # - Re-raise the exception to stop processing - # - Log and skip, then continue - embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure - else: - embeddings.append(res) - - return embeddings - - async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: - """Helper to make an asynchronous API request and extract embedding.""" - try: - async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: - response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) - api_response = await response.json() - - # Adjust this based on your API's actual response structure - if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]: - embedding = api_response['data'][0]['embedding'] - if len(embedding) != self.embedding_dimension: - logger.warning( - f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' - ) - return embedding - elif ( - 'embeddings' in api_response - and isinstance(api_response['embeddings'], list) - and api_response['embeddings'] - ): - embedding = api_response['embeddings'][0] - if len(embedding) != self.embedding_dimension: - logger.warning( - f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' - ) - return embedding - else: - raise ValueError(f'Unexpected API response structure: {api_response}') - - except aiohttp.ClientError as e: - raise ConnectionError(f'API request failed: {e}') from e - except ValueError as e: - raise ValueError(f'Error processing API response: {e}') from e - - async def embed_query(self, text: str) -> List[float]: - """Asynchronously embeds a single query text.""" - results = await self.embed_documents([text]) - if results: - return results[0] - return [] # Or raise an error if embedding a query must always succeed - - -# --- Embedding Model Configuration --- -EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { - 'MiniLM': { # Example for a local Sentence Transformer model - 'type': 'sentence_transformer', - 'model_name': 'sentence-transformers/all-MiniLM-L6-v2', - }, - 'bge-m3': { # Example for a third-party API model - 'type': 'third_party_api', - 'model_name': 'bge-m3', - 'api_endpoint': 'https://api.qhaigc.net/v1/embeddings', - 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'}, - 'payload_template': {'model': 'bge-m3', 'input': ''}, - 'embedding_dimension': 1024, - }, - 'OpenAI-Ada-002': { - 'type': 'third_party_api', - 'model_name': 'text-embedding-ada-002', - 'api_endpoint': 'https://api.openai.com/v1/embeddings', - 'headers': { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set - }, - 'payload_template': { - 'model': 'text-embedding-ada-002', - 'input': '', # Text will be injected here - }, - 'embedding_dimension': 1536, - }, - 'OpenAI-Embedding-3-Small': { - 'type': 'third_party_api', - 'model_name': 'text-embedding-3-small', - 'api_endpoint': 'https://api.openai.com/v1/embeddings', - 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'}, - 'payload_template': { - 'model': 'text-embedding-3-small', - 'input': '', - # "dimensions": 512 # Optional: uncomment if you want a specific output dimension - }, - 'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small - }, -}