mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-16 10:46:03 +00:00
229 lines
6.9 KiB
Python
229 lines
6.9 KiB
Python
"""
|
|
Fake provider factory for tests.
|
|
|
|
Provides a deterministic fake provider that simulates LLM responses without real API calls.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import Mock
|
|
import typing
|
|
|
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
|
|
|
|
class FakeProvider:
    """Deterministic fake provider for unit and integration tests.

    Simulates various provider behaviors:
    - Normal text response
    - Streaming response
    - Timeout error
    - Auth error
    - Rate-limit error
    - Malformed response

    Does not call real LLM vendors.
    Does not require API keys.

    The configuration methods (``returns``, ``returns_streaming``, ``raises``,
    ``timeout``, ``auth_error``, ``rate_limit``, ``malformed``) mutate the
    instance and return ``self`` so they can be chained.
    """

    PONG_RESPONSE = 'LANGBOT_FAKE_PONG'

    # Stock streaming chunks. Joined together they equal the default value of
    # ``default_response``, so the streaming and non-streaming paths agree.
    _DEFAULT_STREAMING_CHUNKS = ('fake ', 'response')

    def __init__(
        self,
        *,
        default_response: str | None = 'fake response',
        streaming_chunks: list[str] | None = None,
        raise_error: Exception | None = None,
        captured_requests: list | None = None,
    ):
        """Create a fake provider.

        Args:
            default_response: Text returned by ``invoke_llm``. ``None`` makes
                ``invoke_llm`` return a message whose ``content`` is ``None``.
            streaming_chunks: Chunks yielded by ``invoke_llm_stream``. When
                omitted, they are derived from ``default_response`` so both
                invoke methods produce the same text. An explicit empty list
                is honoured (the stream then yields nothing).
            raise_error: Exception raised by both invoke methods, if set.
            captured_requests: Optional list that receives one record per
                invocation. It is used as-is (not copied) so callers can share
                it; a fresh list is used when omitted.
        """
        self._default_response = default_response
        if streaming_chunks is None:
            stock = list(self._DEFAULT_STREAMING_CHUNKS)
            # Keep the two-chunk split for the stock response (and when no text
            # is configured); otherwise stream the configured text as a single
            # chunk so invoke_llm and invoke_llm_stream return the same text.
            if default_response is None or default_response == ''.join(stock):
                streaming_chunks = stock
            else:
                streaming_chunks = [default_response]
        self._streaming_chunks = list(streaming_chunks)
        self._raise_error = raise_error
        self._captured_requests = captured_requests if captured_requests is not None else []

    def returns(self, text: str) -> 'FakeProvider':
        """Configure both invoke methods to return ``text`` (streamed as one chunk)."""
        self._default_response = text
        self._streaming_chunks = [text]
        return self

    def returns_streaming(self, chunks: list[str]) -> 'FakeProvider':
        """Configure streaming chunks; the non-streaming response becomes their concatenation."""
        self._streaming_chunks = list(chunks)
        self._default_response = ''.join(chunks)
        return self

    def raises(self, error: Exception) -> 'FakeProvider':
        """Configure both invoke methods to raise ``error`` (after recording the request)."""
        self._raise_error = error
        return self

    def timeout(self) -> 'FakeProvider':
        """Configure provider to simulate a timeout (raises ``TimeoutError``)."""
        return self.raises(TimeoutError('Provider timeout'))

    def auth_error(self) -> 'FakeProvider':
        """Configure provider to simulate an auth error (generic ``Exception``)."""
        return self.raises(Exception('Invalid API key'))

    def rate_limit(self) -> 'FakeProvider':
        """Configure provider to simulate a rate limit (generic ``Exception``)."""
        return self.raises(Exception('Rate limit exceeded'))

    def malformed(self) -> 'FakeProvider':
        """Configure ``invoke_llm`` to return a message with ``content=None``.

        Only the non-streaming path is affected; streaming chunks are unchanged.
        """
        self._default_response = None
        return self

    def get_captured_requests(self) -> list:
        """Return a shallow copy of all captured request records."""
        return self._captured_requests.copy()

    def clear_captured_requests(self):
        """Remove all captured request records (in place, so shared lists see it)."""
        self._captured_requests.clear()

    def _create_message(self, content: str | None) -> provider_message.Message:
        """Create an assistant provider message from text content."""
        return provider_message.Message(
            role='assistant',
            content=content,
        )

    def _create_chunk(
        self,
        content: str,
        is_final: bool = False,
        msg_sequence: int = 0,
    ) -> provider_message.MessageChunk:
        """Create an assistant provider message chunk."""
        return provider_message.MessageChunk(
            role='assistant',
            content=content,
            is_final=is_final,
            msg_sequence=msg_sequence,
        )

    def _capture_request(self, query, model, messages, funcs, extra_args, **extra) -> None:
        """Append one invocation record (plus any ``extra`` keys) for later assertions."""
        record = {
            'query_id': query.query_id if query else None,
            'model': model.model_entity.name if model and hasattr(model, 'model_entity') else None,
            'messages': messages,
            'funcs': funcs,
            'extra_args': extra_args,
        }
        record.update(extra)
        self._captured_requests.append(record)

    async def invoke_llm(
        self,
        query,
        model,
        messages: list,
        funcs: list,
        extra_args: dict,
        remove_think: bool = False,
    ) -> provider_message.Message:
        """Simulate a non-streaming LLM invocation.

        The request is always recorded first, then the configured error (if any)
        is raised. ``remove_think`` is accepted for interface compatibility and
        is ignored.

        Returns:
            An assistant message with the configured response, or with
            ``content=None`` when configured as malformed.
        """
        self._capture_request(query, model, messages, funcs, extra_args)

        if self._raise_error is not None:
            raise self._raise_error

        # A None response means "malformed": return a message without content.
        return self._create_message(self._default_response)

    async def invoke_llm_stream(
        self,
        query,
        model,
        messages: list,
        funcs: list,
        extra_args: dict,
        remove_think: bool = False,
    ) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
        """Simulate a streaming LLM invocation.

        This is an async generator, so the request is recorded and the
        configured error is raised only when iteration starts.
        ``remove_think`` is accepted for interface compatibility and is ignored.

        Yields:
            One chunk per configured streaming chunk; ``msg_sequence`` is the
            chunk index and only the last chunk has ``is_final=True``.
        """
        self._capture_request(query, model, messages, funcs, extra_args, streaming=True)

        if self._raise_error is not None:
            raise self._raise_error

        last_index = len(self._streaming_chunks) - 1
        for i, chunk in enumerate(self._streaming_chunks):
            yield self._create_chunk(chunk, is_final=i == last_index, msg_sequence=i)
|
|
|
|
|
|
def fake_provider(
    default_response: str = 'fake response',
) -> FakeProvider:
    """Build a FakeProvider whose ``invoke_llm`` reply is ``default_response``.

    Args:
        default_response: Text the provider returns from non-streaming calls.

    Returns:
        A freshly constructed FakeProvider.
    """
    provider = FakeProvider(default_response=default_response)
    return provider
|
|
|
|
|
|
def fake_provider_pong() -> FakeProvider:
    """Create a FakeProvider that answers with ``FakeProvider.PONG_RESPONSE``.

    ``returns()`` sets both the non-streaming response and the streaming
    chunks, so ``invoke_llm`` and ``invoke_llm_stream`` both produce the pong.
    Setting only ``default_response`` would leave the stream on its stock chunks.
    """
    return FakeProvider().returns(FakeProvider.PONG_RESPONSE)
|
|
|
|
|
|
def fake_provider_timeout() -> FakeProvider:
    """Return a new FakeProvider whose invocations raise ``TimeoutError('Provider timeout')``."""
    provider = FakeProvider()
    provider.timeout()
    return provider
|
|
|
|
|
|
def fake_provider_auth_error() -> FakeProvider:
    """Return a new FakeProvider whose invocations raise ``Exception('Invalid API key')``."""
    provider = FakeProvider()
    provider.auth_error()
    return provider
|
|
|
|
|
|
def fake_provider_rate_limit() -> FakeProvider:
    """Return a new FakeProvider whose invocations raise ``Exception('Rate limit exceeded')``."""
    provider = FakeProvider()
    provider.rate_limit()
    return provider
|
|
|
|
|
|
def fake_provider_malformed() -> FakeProvider:
    """Return a new FakeProvider whose ``invoke_llm`` yields a message with ``content=None``."""
    provider = FakeProvider()
    provider.malformed()
    return provider
|
|
|
|
|
|
# ============== Mock Model Factory ==============
|
|
|
|
|
|
def fake_model(
    *,
    uuid: str = 'test-model-uuid',
    name: str = 'test-model',
    abilities: list[str] | None = None,
    provider: FakeProvider | None = None,
) -> Mock:
    """Create a mock model with a fake provider attached.

    Args:
        uuid: Value for ``model.model_entity.uuid``.
        name: Value for ``model.model_entity.name``.
        abilities: Value for ``model.model_entity.abilities``. ``None`` selects
            ``['func_call', 'vision']``. An explicit empty list is kept, so tests
            can build a model with no abilities. The list is copied.
        provider: Object attached as ``model.provider``; a new ``FakeProvider()``
            is created when omitted.

    Returns:
        A ``Mock`` whose ``model_entity`` carries uuid/name/abilities and an
        empty ``extra_args`` dict, and whose ``provider`` is the fake provider.
    """
    model = Mock()
    model.model_entity = Mock()
    model.model_entity.uuid = uuid
    model.model_entity.name = name
    # `is not None` (not `or`) so an explicit [] is not replaced by the defaults.
    model.model_entity.abilities = list(abilities) if abilities is not None else ['func_call', 'vision']
    model.model_entity.extra_args = {}

    # Attach fake provider
    model.provider = provider if provider is not None else FakeProvider()

    return model
|