mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
- Replace 23+ individual requester implementations with unified litellmchat.py
- Add litellm_provider field to 27 YAML manifests for provider routing
- Delete redundant requester subclasses
- Add unit tests for LiteLLMRequester (29 tests)
- Fix num_retries parameter name (was max_retries)
- Fix exception handling order for subclass exceptions

LiteLLM provides a unified API for 100+ providers, eliminating the need for provider-specific requesters.
636 lines
23 KiB
Python
636 lines
23 KiB
Python
"""
|
|
Tests for LiteLLMRequester - unified requester for chat, embedding, and rerank.
|
|
|
|
These tests verify:
|
|
- Parameter building and LiteLLM API calls
|
|
- Response processing and usage extraction
|
|
- Error handling and exception translation
|
|
- Model name building with provider prefix
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
|
|
import litellm
|
|
|
|
from langbot.pkg.provider.modelmgr.requesters import litellmchat
|
|
from langbot.pkg.provider.modelmgr import errors
|
|
|
|
|
|
class MockRuntimeModel:
    """Stand-in for a runtime LLM model used by the requester tests.

    Exposes ``model_entity`` (carrying ``name`` and an empty ``extra_args``
    dict) and ``provider.token_mgr.get_token()``, which returns ``api_key``.
    """

    def __init__(self, model_name: str = 'gpt-4o', api_key: str = 'test-key'):
        # Assemble the entity first, then attach it to the instance.
        entity = Mock()
        entity.name = model_name
        entity.extra_args = {}
        self.model_entity = entity

        # The token manager hands back the configured API key on every call.
        token_mgr = Mock()
        token_mgr.get_token = Mock(return_value=api_key)
        provider = Mock()
        provider.token_mgr = token_mgr
        self.provider = provider
|
|
|
|
|
|
class MockRuntimeEmbeddingModel:
    """Stand-in for a runtime embedding model used by the requester tests.

    Exposes ``model_entity`` (carrying ``name`` and an empty ``extra_args``
    dict) and ``provider.token_mgr.get_token()``, which returns ``api_key``.
    """

    def __init__(self, model_name: str = 'text-embedding-3-small', api_key: str = 'test-key'):
        # Assemble the entity first, then attach it to the instance.
        entity = Mock()
        entity.name = model_name
        entity.extra_args = {}
        self.model_entity = entity

        # The token manager hands back the configured API key on every call.
        token_mgr = Mock()
        token_mgr.get_token = Mock(return_value=api_key)
        provider = Mock()
        provider.token_mgr = token_mgr
        self.provider = provider
|
|
|
|
|
|
class MockRuntimeRerankModel:
    """Stand-in for a runtime rerank model used by the requester tests.

    Exposes ``model_entity`` (carrying ``name`` and an empty ``extra_args``
    dict) and ``provider.token_mgr.get_token()``, which returns ``api_key``.
    """

    def __init__(self, model_name: str = 'cohere/rerank-english-v3.0', api_key: str = 'test-key'):
        # Assemble the entity first, then attach it to the instance.
        entity = Mock()
        entity.name = model_name
        entity.extra_args = {}
        self.model_entity = entity

        # The token manager hands back the configured API key on every call.
        token_mgr = Mock()
        token_mgr.get_token = Mock(return_value=api_key)
        provider = Mock()
        provider.token_mgr = token_mgr
        self.provider = provider
|
|
|
|
|
|
class TestBuildLiteLLMModelName:
    """Checks for ``LiteLLMRequester._build_litellm_model_name``."""

    def test_no_provider_prefix(self):
        """With an empty configured provider the name is expected unchanged."""
        settings = {'custom_llm_provider': ''}
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config=settings)

        assert subject._build_litellm_model_name('gpt-4o') == 'gpt-4o'

    def test_with_provider_prefix(self):
        """A configured provider is expected as a ``provider/`` prefix."""
        settings = {'custom_llm_provider': 'openai'}
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config=settings)

        assert subject._build_litellm_model_name('gpt-4o') == 'openai/gpt-4o'

    def test_override_provider(self):
        """A provider passed as an argument is expected to win over the config."""
        settings = {'custom_llm_provider': 'openai'}
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config=settings)

        built = subject._build_litellm_model_name('claude-3', custom_llm_provider='anthropic')

        assert built == 'anthropic/claude-3'
|
|
|
|
|
|
class TestExtractUsage:
    """Checks for ``LiteLLMRequester._extract_usage``."""

    def test_extract_usage_with_data(self):
        """Non-zero token counts are expected in the result, plus alias keys."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        reply = Mock()
        reply.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150)

        usage = subject._extract_usage(reply)

        assert usage['prompt_tokens'] == 100
        assert usage['completion_tokens'] == 50
        assert usage['total_tokens'] == 150
        # Compatibility aliases for the prompt/completion counts.
        assert usage['input_tokens'] == 100
        assert usage['output_tokens'] == 50

    def test_extract_usage_with_zero_values(self):
        """Zero token counts are expected to be reported as 0."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        reply = Mock()
        reply.usage = Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)

        usage = subject._extract_usage(reply)

        assert usage['prompt_tokens'] == 0
        assert usage['completion_tokens'] == 0
|
|
|
|
|
|
class TestProcessThinkingContent:
    """Checks for ``LiteLLMRequester._process_thinking_content``.

    NOTE(review): the reasoning-marker literals below are reproduced verbatim
    from the fixture; they look like placeholder tokens -- confirm they match
    the markers the requester actually strips.
    """

    def test_no_thinking_markers(self):
        """Plain content is expected back unchanged."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})

        processed = subject._process_thinking_content('Hello world', None, remove_think=True)

        assert processed == 'Hello world'

    def test_remove_thinking_markers(self):
        """With ``remove_think=True`` only the text after the markers is expected."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        raw = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'

        processed = subject._process_thinking_content(raw, None, remove_think=True)

        assert processed == 'The answer is 42.'

    def test_preserve_thinking_markers(self):
        """With ``remove_think=False`` both marker and answer are expected to remain."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        raw = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'

        processed = subject._process_thinking_content(raw, None, remove_think=False)

        assert 'CRETIRE_REASONING_BEGINk' in processed
        assert 'The answer is 42.' in processed

    def test_empty_content(self):
        """An empty string is expected back as an empty string."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})

        processed = subject._process_thinking_content('', None, remove_think=True)

        assert processed == ''
|
|
|
|
|
|
class TestBuildCommonArgs:
    """Checks for ``LiteLLMRequester._build_common_args``.

    The method is called with a dict that it is expected to fill in place.
    """

    def test_build_args_with_all_params(self):
        """Every configured option is expected to appear in the built args."""
        requester = litellmchat.LiteLLMRequester(
            ap=Mock(),
            config={
                'base_url': 'https://api.openai.com/v1',
                'timeout': 60,
                'drop_params': True,
                'num_retries': 3,
                'api_version': '2024-01-01',
            },
        )

        args = {}
        requester._build_common_args(args)

        # The config key ``base_url`` is expected under the name ``api_base``.
        assert args['api_base'] == 'https://api.openai.com/v1'
        assert args['timeout'] == 60
        # Identity check: ``== True`` would also accept a truthy ``1``.
        assert args['drop_params'] is True
        assert args['num_retries'] == 3
        assert args['api_version'] == '2024-01-01'

    def test_build_args_without_retry_params(self):
        """``include_retry_params=False`` is expected to omit ``num_retries``."""
        requester = litellmchat.LiteLLMRequester(
            ap=Mock(),
            config={
                'base_url': 'https://api.openai.com/v1',
                'timeout': 60,
                'num_retries': 3,
            },
        )

        args = {}
        requester._build_common_args(args, include_retry_params=False)

        assert args['api_base'] == 'https://api.openai.com/v1'
        assert args['timeout'] == 60
        assert 'num_retries' not in args
|
|
|
|
|
|
class TestHandleLiteLLMError:
    """Checks for ``LiteLLMRequester._handle_litellm_error``.

    Each case feeds one exception to the handler and expects a
    ``RequesterError`` whose message contains a specific fragment.
    """

    @staticmethod
    def _expect_translation(error, fragment):
        """Run the handler on ``error`` and check the raised message for ``fragment``."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})

        with pytest.raises(errors.RequesterError) as caught:
            subject._handle_litellm_error(error)

        assert fragment in str(caught.value)

    def test_bad_request_error(self):
        """``BadRequestError`` translation."""
        # LiteLLM exceptions are built with message, model and llm_provider.
        failure = litellm.BadRequestError(message='test error', model='gpt-4o', llm_provider='openai')
        self._expect_translation(failure, '请求参数错误')

    def test_authentication_error(self):
        """``AuthenticationError`` translation."""
        failure = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
        self._expect_translation(failure, 'API key 无效')

    def test_rate_limit_error(self):
        """``RateLimitError`` translation."""
        failure = litellm.RateLimitError(message='rate limited', model='gpt-4o', llm_provider='openai')
        self._expect_translation(failure, '请求过于频繁')

    def test_timeout_error(self):
        """``Timeout`` translation."""
        failure = litellm.Timeout(message='timeout', model='gpt-4o', llm_provider='openai')
        self._expect_translation(failure, '请求超时')

    def test_context_window_error(self):
        """``ContextWindowExceededError`` translation."""
        failure = litellm.ContextWindowExceededError(
            message='context too long', model='gpt-4o', llm_provider='openai'
        )
        self._expect_translation(failure, '上下文长度超限')

    def test_unknown_error(self):
        """A plain ``Exception`` is expected to map to the unknown-error message."""
        self._expect_translation(Exception('unknown'), '未知错误')
|
|
|
|
|
|
class TestInvokeLLM:
    """Checks for ``LiteLLMRequester.invoke_llm`` with ``acompletion`` patched out."""

    @pytest.mark.asyncio
    async def test_invoke_llm_basic(self):
        """A plain text reply is expected back with its usage counts."""
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        app = Mock()
        app.tool_mgr = Mock()
        app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)

        subject = litellmchat.LiteLLMRequester(
            ap=app,
            config={'base_url': 'https://api.openai.com/v1', 'timeout': 60},
        )
        llm_model = MockRuntimeModel('gpt-4o', 'test-api-key')

        # Fake completion: a single choice whose message dumps to a dict.
        reply_message = Mock()
        reply_message.model_dump = Mock(
            return_value={
                'role': 'assistant',
                'content': 'Hello! How can I help you?',
            }
        )
        choice = Mock()
        choice.message = reply_message
        completion = Mock()
        completion.choices = [choice]
        completion.usage = Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

        conversation = [provider_message.Message(role='user', content='Hello')]

        # Patch the name as it is bound inside the requester module.
        with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=completion):
            result_msg, usage = await subject.invoke_llm(
                query=None,
                model=llm_model,
                messages=conversation,
            )

        assert result_msg.role == 'assistant'
        assert result_msg.content == 'Hello! How can I help you?'
        assert usage['prompt_tokens'] == 10
        assert usage['completion_tokens'] == 20

    @pytest.mark.asyncio
    async def test_invoke_llm_with_tools(self):
        """A reply carrying tool calls is expected to surface ``tool_calls``."""
        import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        app = Mock()
        app.tool_mgr = Mock()
        app.tool_mgr.generate_tools_for_openai = AsyncMock(
            return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
        )

        subject = litellmchat.LiteLLMRequester(ap=app, config={})
        llm_model = MockRuntimeModel('gpt-4o', 'test-api-key')

        # Fake completion: no text content, one function tool call.
        reply_message = Mock()
        reply_message.model_dump = Mock(
            return_value={
                'role': 'assistant',
                'content': None,
                'tool_calls': [
                    {'id': 'call_123', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}}
                ],
            }
        )
        choice = Mock()
        choice.message = reply_message
        completion = Mock()
        completion.choices = [choice]
        completion.usage = Mock(prompt_tokens=15, completion_tokens=10, total_tokens=25)

        conversation = [provider_message.Message(role='user', content='What is the weather?')]

        # Spec'd mock so the tool object only exposes LLMTool attributes.
        weather_tool = Mock(spec=resource_tool.LLMTool)
        weather_tool.name = 'get_weather'
        weather_tool.description = 'Get weather'

        with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=completion):
            result_msg, usage = await subject.invoke_llm(
                query=None,
                model=llm_model,
                messages=conversation,
                funcs=[weather_tool],
            )

        assert result_msg.tool_calls is not None

    @pytest.mark.asyncio
    async def test_invoke_llm_error_handling(self):
        """A LiteLLM auth failure is expected to surface as ``RequesterError``."""
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        app = Mock()
        app.tool_mgr = Mock()
        app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)

        subject = litellmchat.LiteLLMRequester(ap=app, config={})
        llm_model = MockRuntimeModel('gpt-4o', 'test-api-key')
        conversation = [provider_message.Message(role='user', content='Hello')]

        failure = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')

        # The patched completion call raises instead of returning.
        patched = patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, side_effect=failure)
        with patched, pytest.raises(errors.RequesterError) as caught:
            await subject.invoke_llm(
                query=None,
                model=llm_model,
                messages=conversation,
            )

        assert 'API key 无效' in str(caught.value)
|
|
|
|
|
|
class TestInvokeEmbedding:
    """Checks for ``LiteLLMRequester.invoke_embedding`` with ``aembedding`` patched out."""

    @pytest.mark.asyncio
    async def test_invoke_embedding_basic(self):
        """Two inputs are expected to yield two vectors plus usage counts."""
        subject = litellmchat.LiteLLMRequester(
            ap=Mock(),
            config={'base_url': 'https://api.openai.com/v1'},
        )
        embedding_model = MockRuntimeEmbeddingModel('text-embedding-3-small', 'test-api-key')

        # Fake embedding reply: one data item per input text.
        first_vector = Mock(embedding=[0.1, 0.2, 0.3])
        second_vector = Mock(embedding=[0.4, 0.5, 0.6])
        reply = Mock()
        reply.data = [first_vector, second_vector]
        reply.usage = Mock(prompt_tokens=20, completion_tokens=0, total_tokens=20)

        with patch.object(litellmchat, 'aembedding', new_callable=AsyncMock, return_value=reply):
            embeddings, usage = await subject.invoke_embedding(
                model=embedding_model,
                input_text=['Hello', 'World'],
            )

        assert len(embeddings) == 2
        assert embeddings[0] == [0.1, 0.2, 0.3]
        assert embeddings[1] == [0.4, 0.5, 0.6]
        assert usage['prompt_tokens'] == 20
|
|
|
|
|
|
class TestInvokeRerank:
    """Checks for ``LiteLLMRequester.invoke_rerank`` with ``arerank`` patched out."""

    @pytest.mark.asyncio
    async def test_invoke_rerank_basic(self):
        """Three documents are expected to yield three scored results."""
        subject = litellmchat.LiteLLMRequester(
            ap=Mock(),
            config={'base_url': 'https://api.cohere.ai'},
        )
        rerank_model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')

        # Fake rerank reply: one index/score pair per document.
        reply = Mock()
        reply.results = [
            {'index': 0, 'relevance_score': 0.95},
            {'index': 1, 'relevance_score': 0.3},
            {'index': 2, 'relevance_score': 0.8},
        ]

        with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=reply):
            results = await subject.invoke_rerank(
                model=rerank_model,
                query='What is the capital of France?',
                documents=['Paris is the capital.', 'London is a city.', 'France is in Europe.'],
            )

        assert len(results) == 3
        # Returned scores are expected to lie within [0, 1].
        top = results[0]
        assert top['index'] == 0
        assert 0 <= top['relevance_score'] <= 1

    @pytest.mark.asyncio
    async def test_invoke_rerank_normalization(self):
        """Scores 0.9 and 0.1 are expected to come back as 1.0 and 0.0."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        rerank_model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')

        # Two distinct raw scores.
        reply = Mock()
        reply.results = [
            {'index': 0, 'relevance_score': 0.9},
            {'index': 1, 'relevance_score': 0.1},
        ]

        with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=reply):
            results = await subject.invoke_rerank(
                model=rerank_model,
                query='test query',
                documents=['doc1', 'doc2'],
            )

        # Expected mapping: highest raw score -> 1.0, lowest -> 0.0.
        highest, lowest = results[0], results[1]
        assert highest['relevance_score'] == 1.0
        assert lowest['relevance_score'] == 0.0

    @pytest.mark.asyncio
    async def test_invoke_rerank_single_document(self):
        """A lone score is expected to be returned unchanged."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        rerank_model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')

        reply = Mock()
        reply.results = [{'index': 0, 'relevance_score': 0.5}]

        with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=reply):
            results = await subject.invoke_rerank(
                model=rerank_model,
                query='test query',
                documents=['doc1'],
            )

        assert len(results) == 1
        # With a single result min == max, so the raw score is expected as is.
        assert results[0]['relevance_score'] == 0.5
|
|
|
|
|
|
class TestConvertMessages:
    """Checks for ``LiteLLMRequester._convert_messages``."""

    def test_convert_simple_message(self):
        """A text-only user message is expected to keep its role and content."""
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        greeting = provider_message.Message(role='user', content='Hello')

        converted = subject._convert_messages([greeting])

        assert len(converted) == 1
        only = converted[0]
        assert only['role'] == 'user'
        assert only['content'] == 'Hello'

    def test_convert_message_with_image_base64(self):
        """An ``image_base64`` part is expected to come out as an ``image_url`` part."""
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        parts = [
            {'type': 'text', 'text': 'What is in this image?'},
            {'type': 'image_base64', 'image_base64': 'data:image/png;base64,abc123'},
        ]
        question = provider_message.Message(role='user', content=parts)

        converted = subject._convert_messages([question])

        assert len(converted) == 1
        content = converted[0]['content']
        assert isinstance(content, list)
        # The base64 payload is expected under image_url.url.
        image_parts = [part for part in content if part.get('type') == 'image_url']
        image_part = image_parts[0]
        assert 'image_url' in image_part
        assert image_part['image_url']['url'] == 'data:image/png;base64,abc123'

    def test_convert_message_with_multiple_text_parts(self):
        """Several text parts are expected to stay a list rather than be merged."""
        import langbot_plugin.api.entities.builtin.provider.message as provider_message

        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={})
        parts = [
            {'type': 'text', 'text': 'Hello'},
            {'type': 'text', 'text': 'World'},
        ]
        two_part = provider_message.Message(role='user', content=parts)

        converted = subject._convert_messages([two_part])

        assert len(converted) == 1
        # The parts are passed through as a list; merging is left to LiteLLM.
        assert isinstance(converted[0]['content'], list)
|
|
|
|
|
|
class TestScanModels:
    """Checks for ``LiteLLMRequester.scan_models``."""

    @pytest.mark.asyncio
    async def test_scan_models_basic(self):
        """A three-model listing is expected to be returned and typed."""
        subject = litellmchat.LiteLLMRequester(
            ap=Mock(),
            config={'base_url': 'https://api.openai.com/v1', 'timeout': 60},
        )

        # Fake HTTP reply whose JSON body lists three model ids.
        listing = {
            'data': [
                {'id': 'gpt-4o'},
                {'id': 'text-embedding-3-small'},
                {'id': 'gpt-3.5-turbo'},
            ]
        }
        http_reply = Mock()
        http_reply.json = Mock(return_value=listing)
        http_reply.raise_for_status = Mock()

        # Object yielded by ``async with httpx.AsyncClient(...)``.
        session = Mock()
        session.get = AsyncMock(return_value=http_reply)

        with patch('httpx.AsyncClient') as client_cls:
            client_cls.return_value.__aenter__ = AsyncMock(return_value=session)

            result = await subject.scan_models(api_key='test-key')

        assert 'models' in result
        models = result['models']
        assert len(models) == 3
        # The first entry is expected to be an LLM model.
        assert models[0]['type'] == 'llm'
        # Exactly one entry is expected to be classified as an embedding model.
        embedding_models = [entry for entry in models if entry['type'] == 'embedding']
        assert len(embedding_models) == 1

    @pytest.mark.asyncio
    async def test_scan_models_no_base_url(self):
        """An empty ``base_url`` is expected to raise ``RequesterError``."""
        subject = litellmchat.LiteLLMRequester(ap=Mock(), config={'base_url': ''})

        with pytest.raises(errors.RequesterError) as caught:
            await subject.scan_models()

        assert 'Base URL required' in str(caught.value)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == '__main__':
    # pytest.main returns the exit code; propagate it so a failing run
    # does not exit with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))
|