""" Fake provider factory for tests. Provides a deterministic fake provider that simulates LLM responses without real API calls. """ from __future__ import annotations from unittest.mock import Mock import typing import langbot_plugin.api.entities.builtin.provider.message as provider_message class FakeProvider: """Deterministic fake provider for unit and integration tests. Simulates various provider behaviors: - Normal text response - Streaming response - Timeout error - Auth error - Rate-limit error - Malformed response Does not call real LLM vendors. Does not require API keys. """ PONG_RESPONSE = 'LANGBOT_FAKE_PONG' def __init__( self, *, default_response: str = 'fake response', streaming_chunks: list[str] = None, raise_error: Exception = None, captured_requests: list = None, ): self._default_response = default_response self._streaming_chunks = streaming_chunks or ['fake ', 'response'] self._raise_error = raise_error self._captured_requests = captured_requests if captured_requests is not None else [] def returns(self, text: str) -> 'FakeProvider': """Configure provider to return a specific text response.""" self._default_response = text self._streaming_chunks = [text] return self def returns_streaming(self, chunks: list[str]) -> 'FakeProvider': """Configure provider to return streaming chunks.""" self._streaming_chunks = chunks self._default_response = ''.join(chunks) return self def raises(self, error: Exception) -> 'FakeProvider': """Configure provider to raise an error.""" self._raise_error = error return self def timeout(self) -> 'FakeProvider': """Configure provider to simulate timeout.""" return self.raises(TimeoutError('Provider timeout')) def auth_error(self) -> 'FakeProvider': """Configure provider to simulate auth error.""" return self.raises(Exception('Invalid API key')) def rate_limit(self) -> 'FakeProvider': """Configure provider to simulate rate limit.""" return self.raises(Exception('Rate limit exceeded')) def malformed(self) -> 'FakeProvider': """Configure provider to simulate malformed response.""" self._default_response = None return self def get_captured_requests(self) -> list: """Get all captured request arguments for assertions.""" return self._captured_requests.copy() def clear_captured_requests(self): """Clear captured requests.""" self._captured_requests.clear() def _create_message(self, content: str) -> provider_message.Message: """Create a provider message from text content.""" return provider_message.Message( role='assistant', content=content, ) def _create_chunk( self, content: str, is_final: bool = False, msg_sequence: int = 0, ) -> provider_message.MessageChunk: """Create a provider message chunk.""" return provider_message.MessageChunk( role='assistant', content=content, is_final=is_final, msg_sequence=msg_sequence, ) async def invoke_llm( self, query, model, messages: list, funcs: list, extra_args: dict, remove_think: bool = False, ) -> provider_message.Message: """Simulate non-streaming LLM invocation.""" # Capture request for assertions self._captured_requests.append( { 'query_id': query.query_id if query else None, 'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, 'messages': messages, 'funcs': funcs, 'extra_args': extra_args, } ) # Simulate error if configured if self._raise_error: raise self._raise_error # Return response if self._default_response is None: # Malformed response return provider_message.Message(role='assistant', content=None) return self._create_message(self._default_response) async def invoke_llm_stream( self, query, model, messages: list, funcs: list, extra_args: dict, remove_think: bool = False, ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]: """Simulate streaming LLM invocation.""" # Capture request for assertions self._captured_requests.append( { 'query_id': query.query_id if query else None, 'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None, 'messages': messages, 'funcs': funcs, 'extra_args': extra_args, 'streaming': True, } ) # Simulate error if configured if self._raise_error: raise self._raise_error # Yield chunks for i, chunk in enumerate(self._streaming_chunks): is_final = i == len(self._streaming_chunks) - 1 yield self._create_chunk(chunk, is_final=is_final, msg_sequence=i) def fake_provider( default_response: str = 'fake response', ) -> FakeProvider: """Create a FakeProvider with optional default response.""" return FakeProvider(default_response=default_response) def fake_provider_pong() -> FakeProvider: """Create a FakeProvider that returns the pong response.""" return FakeProvider(default_response=FakeProvider.PONG_RESPONSE) def fake_provider_timeout() -> FakeProvider: """Create a FakeProvider that simulates timeout.""" return FakeProvider().timeout() def fake_provider_auth_error() -> FakeProvider: """Create a FakeProvider that simulates auth error.""" return FakeProvider().auth_error() def fake_provider_rate_limit() -> FakeProvider: """Create a FakeProvider that simulates rate limit.""" return FakeProvider().rate_limit() def fake_provider_malformed() -> FakeProvider: """Create a FakeProvider that simulates malformed response.""" return FakeProvider().malformed() # ============== Mock Model Factory ============== def fake_model( *, uuid: str = 'test-model-uuid', name: str = 'test-model', abilities: list[str] = None, provider: FakeProvider = None, ) -> Mock: """Create a mock model with a fake provider.""" model = Mock() model.model_entity = Mock() model.model_entity.uuid = uuid model.model_entity.name = name model.model_entity.abilities = abilities or ['func_call', 'vision'] model.model_entity.extra_args = {} # Attach fake provider if provider is None: provider = FakeProvider() model.provider = provider return model