Files
LangBot/tests/unit_tests/provider/test_litellmchat.py
huanghuoguoguo de61b5d368 refactor(provider): use LiteLLM as unified LLM requester backend
- Replace 23+ individual requester implementations with unified litellmchat.py
  - Add litellm_provider field to 27 YAML manifests for provider routing
  - Delete redundant requester subclasses
  - Add unit tests for LiteLLMRequester (29 tests)
  - Fix num_retries parameter name (was max_retries)
  - Fix exception handling order for subclass exceptions

  LiteLLM provides unified API for 100+ providers, eliminating need for
  provider-specific requesters.
2026-04-24 22:32:33 +08:00

636 lines
23 KiB
Python

"""
Tests for LiteLLMRequester - unified requester for chat, embedding, and rerank.
These tests verify:
- Parameter building and LiteLLM API calls
- Response processing and usage extraction
- Error handling and exception translation
- Model name building with provider prefix
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
import litellm
from langbot.pkg.provider.modelmgr.requesters import litellmchat
from langbot.pkg.provider.modelmgr import errors
class MockRuntimeModel:
"""Mock RuntimeLLMModel for testing"""
def __init__(self, model_name: str = 'gpt-4o', api_key: str = 'test-key'):
self.model_entity = Mock()
self.model_entity.name = model_name
self.model_entity.extra_args = {}
self.provider = Mock()
self.provider.token_mgr = Mock()
self.provider.token_mgr.get_token = Mock(return_value=api_key)
class MockRuntimeEmbeddingModel:
"""Mock RuntimeEmbeddingModel for testing"""
def __init__(self, model_name: str = 'text-embedding-3-small', api_key: str = 'test-key'):
self.model_entity = Mock()
self.model_entity.name = model_name
self.model_entity.extra_args = {}
self.provider = Mock()
self.provider.token_mgr = Mock()
self.provider.token_mgr.get_token = Mock(return_value=api_key)
class MockRuntimeRerankModel:
"""Mock RuntimeRerankModel for testing"""
def __init__(self, model_name: str = 'cohere/rerank-english-v3.0', api_key: str = 'test-key'):
self.model_entity = Mock()
self.model_entity.name = model_name
self.model_entity.extra_args = {}
self.provider = Mock()
self.provider.token_mgr = Mock()
self.provider.token_mgr.get_token = Mock(return_value=api_key)
class TestBuildLiteLLMModelName:
"""Test _build_litellm_model_name method"""
def test_no_provider_prefix(self):
"""Test model name without provider prefix"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': ''})
result = requester._build_litellm_model_name('gpt-4o')
assert result == 'gpt-4o'
def test_with_provider_prefix(self):
"""Test model name with provider prefix"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
result = requester._build_litellm_model_name('gpt-4o')
assert result == 'openai/gpt-4o'
def test_override_provider(self):
"""Test override provider via parameter"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
result = requester._build_litellm_model_name('claude-3', custom_llm_provider='anthropic')
assert result == 'anthropic/claude-3'
class TestExtractUsage:
"""Test _extract_usage method"""
def test_extract_usage_with_data(self):
"""Test extraction with valid usage data"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
response = Mock()
response.usage = Mock()
response.usage.prompt_tokens = 100
response.usage.completion_tokens = 50
response.usage.total_tokens = 150
result = requester._extract_usage(response)
assert result['prompt_tokens'] == 100
assert result['completion_tokens'] == 50
assert result['total_tokens'] == 150
assert result['input_tokens'] == 100 # Compatibility alias
assert result['output_tokens'] == 50 # Compatibility alias
def test_extract_usage_with_zero_values(self):
"""Test extraction when values are 0"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
response = Mock()
response.usage = Mock()
response.usage.prompt_tokens = 0
response.usage.completion_tokens = 0
response.usage.total_tokens = 0
result = requester._extract_usage(response)
assert result['prompt_tokens'] == 0
assert result['completion_tokens'] == 0
class TestProcessThinkingContent:
"""Test _process_thinking_content method"""
def test_no_thinking_markers(self):
"""Test content without thinking markers"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
result = requester._process_thinking_content('Hello world', None, remove_think=True)
assert result == 'Hello world'
def test_remove_thinking_markers(self):
"""Test removing thinking markers when remove_think=True"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
result = requester._process_thinking_content(content, None, remove_think=True)
assert result == 'The answer is 42.'
def test_preserve_thinking_markers(self):
"""Test preserving thinking markers when remove_think=False"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
result = requester._process_thinking_content(content, None, remove_think=False)
assert 'CRETIRE_REASONING_BEGINk' in result
assert 'The answer is 42.' in result
def test_empty_content(self):
"""Test empty content"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
result = requester._process_thinking_content('', None, remove_think=True)
assert result == ''
class TestBuildCommonArgs:
"""Test _build_common_args method"""
def test_build_args_with_all_params(self):
"""Test building args with all config params"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.openai.com/v1',
'timeout': 60,
'drop_params': True,
'num_retries': 3,
'api_version': '2024-01-01',
},
)
args = {}
requester._build_common_args(args)
assert args['api_base'] == 'https://api.openai.com/v1'
assert args['timeout'] == 60
assert args['drop_params'] == True
assert args['num_retries'] == 3
assert args['api_version'] == '2024-01-01'
def test_build_args_without_retry_params(self):
"""Test building args without retry params for embedding/rerank"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.openai.com/v1',
'timeout': 60,
'num_retries': 3,
},
)
args = {}
requester._build_common_args(args, include_retry_params=False)
assert args['api_base'] == 'https://api.openai.com/v1'
assert args['timeout'] == 60
assert 'num_retries' not in args
class TestHandleLiteLLMError:
"""Test _handle_litellm_error method"""
def test_bad_request_error(self):
"""Test BadRequestError translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
# Create proper LiteLLM exception with required args
error = litellm.BadRequestError(message='test error', model='gpt-4o', llm_provider='openai')
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(error)
assert '请求参数错误' in str(exc_info.value)
def test_authentication_error(self):
"""Test AuthenticationError translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(error)
assert 'API key 无效' in str(exc_info.value)
def test_rate_limit_error(self):
"""Test RateLimitError translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
error = litellm.RateLimitError(message='rate limited', model='gpt-4o', llm_provider='openai')
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(error)
assert '请求过于频繁' in str(exc_info.value)
def test_timeout_error(self):
"""Test Timeout translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
error = litellm.Timeout(message='timeout', model='gpt-4o', llm_provider='openai')
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(error)
assert '请求超时' in str(exc_info.value)
def test_context_window_error(self):
"""Test ContextWindowExceededError translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
error = litellm.ContextWindowExceededError(message='context too long', model='gpt-4o', llm_provider='openai')
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(error)
assert '上下文长度超限' in str(exc_info.value)
def test_unknown_error(self):
"""Test unknown error translation"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
with pytest.raises(errors.RequesterError) as exc_info:
requester._handle_litellm_error(Exception('unknown'))
assert '未知错误' in str(exc_info.value)
class TestInvokeLLM:
"""Test invoke_llm method"""
@pytest.mark.asyncio
async def test_invoke_llm_basic(self):
"""Test basic LLM invocation"""
mock_ap = Mock()
mock_ap.tool_mgr = Mock()
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
requester = litellmchat.LiteLLMRequester(
ap=mock_ap,
config={
'base_url': 'https://api.openai.com/v1',
'timeout': 60,
},
)
model = MockRuntimeModel('gpt-4o', 'test-api-key')
# Mock LiteLLM response
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message = Mock()
mock_response.choices[0].message.model_dump = Mock(
return_value={
'role': 'assistant',
'content': 'Hello! How can I help you?',
}
)
mock_response.usage = Mock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_response.usage.total_tokens = 30
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [provider_message.Message(role='user', content='Hello')]
# Patch acompletion at the import location
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
result_msg, usage = await requester.invoke_llm(
query=None,
model=model,
messages=messages,
)
assert result_msg.role == 'assistant'
assert result_msg.content == 'Hello! How can I help you?'
assert usage['prompt_tokens'] == 10
assert usage['completion_tokens'] == 20
@pytest.mark.asyncio
async def test_invoke_llm_with_tools(self):
"""Test LLM invocation with function calling"""
mock_ap = Mock()
mock_ap.tool_mgr = Mock()
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
)
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
model = MockRuntimeModel('gpt-4o', 'test-api-key')
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message = Mock()
mock_response.choices[0].message.model_dump = Mock(
return_value={
'role': 'assistant',
'content': None,
'tool_calls': [
{'id': 'call_123', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}}
],
}
)
mock_response.usage = Mock()
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_response.usage.total_tokens = 25
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [provider_message.Message(role='user', content='What is the weather?')]
# Create proper LLMTool with all required fields
funcs = [Mock(spec=resource_tool.LLMTool)]
funcs[0].name = 'get_weather'
funcs[0].description = 'Get weather'
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
result_msg, usage = await requester.invoke_llm(
query=None,
model=model,
messages=messages,
funcs=funcs,
)
assert result_msg.tool_calls is not None
@pytest.mark.asyncio
async def test_invoke_llm_error_handling(self):
"""Test LLM invocation error handling"""
mock_ap = Mock()
mock_ap.tool_mgr = Mock()
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
model = MockRuntimeModel('gpt-4o', 'test-api-key')
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [provider_message.Message(role='user', content='Hello')]
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, side_effect=error):
with pytest.raises(errors.RequesterError) as exc_info:
await requester.invoke_llm(
query=None,
model=model,
messages=messages,
)
assert 'API key 无效' in str(exc_info.value)
class TestInvokeEmbedding:
"""Test invoke_embedding method"""
@pytest.mark.asyncio
async def test_invoke_embedding_basic(self):
"""Test basic embedding invocation"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.openai.com/v1',
},
)
model = MockRuntimeEmbeddingModel('text-embedding-3-small', 'test-api-key')
# Mock LiteLLM embedding response
mock_response = Mock()
mock_response.data = [
Mock(embedding=[0.1, 0.2, 0.3]),
Mock(embedding=[0.4, 0.5, 0.6]),
]
mock_response.usage = Mock()
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 0
mock_response.usage.total_tokens = 20
with patch.object(litellmchat, 'aembedding', new_callable=AsyncMock, return_value=mock_response):
embeddings, usage = await requester.invoke_embedding(
model=model,
input_text=['Hello', 'World'],
)
assert len(embeddings) == 2
assert embeddings[0] == [0.1, 0.2, 0.3]
assert embeddings[1] == [0.4, 0.5, 0.6]
assert usage['prompt_tokens'] == 20
class TestInvokeRerank:
"""Test invoke_rerank method"""
@pytest.mark.asyncio
async def test_invoke_rerank_basic(self):
"""Test basic rerank invocation"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.cohere.ai',
},
)
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
# Mock LiteLLM rerank response
mock_response = Mock()
mock_response.results = [
{'index': 0, 'relevance_score': 0.95},
{'index': 1, 'relevance_score': 0.3},
{'index': 2, 'relevance_score': 0.8},
]
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
results = await requester.invoke_rerank(
model=model,
query='What is the capital of France?',
documents=['Paris is the capital.', 'London is a city.', 'France is in Europe.'],
)
assert len(results) == 3
# Scores should be normalized
assert results[0]['index'] == 0
assert results[0]['relevance_score'] >= 0 and results[0]['relevance_score'] <= 1
@pytest.mark.asyncio
async def test_invoke_rerank_normalization(self):
"""Test rerank score normalization"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
# Mock response with varying scores
mock_response = Mock()
mock_response.results = [
{'index': 0, 'relevance_score': 0.9},
{'index': 1, 'relevance_score': 0.1},
]
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
results = await requester.invoke_rerank(
model=model,
query='test query',
documents=['doc1', 'doc2'],
)
# After normalization: 0.9 -> 1.0, 0.1 -> 0.0
assert results[0]['relevance_score'] == 1.0
assert results[1]['relevance_score'] == 0.0
@pytest.mark.asyncio
async def test_invoke_rerank_single_document(self):
"""Test rerank with single document (no normalization needed)"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
mock_response = Mock()
mock_response.results = [
{'index': 0, 'relevance_score': 0.5},
]
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
results = await requester.invoke_rerank(
model=model,
query='test query',
documents=['doc1'],
)
assert len(results) == 1
# Single score stays as is (min==max, no normalization)
assert results[0]['relevance_score'] == 0.5
class TestConvertMessages:
"""Test _convert_messages method"""
def test_convert_simple_message(self):
"""Test converting simple text message"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [provider_message.Message(role='user', content='Hello')]
result = requester._convert_messages(messages)
assert len(result) == 1
assert result[0]['role'] == 'user'
assert result[0]['content'] == 'Hello'
def test_convert_message_with_image_base64(self):
"""Test converting message with image_base64 content"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [
provider_message.Message(
role='user',
content=[
{'type': 'text', 'text': 'What is in this image?'},
{'type': 'image_base64', 'image_base64': 'data:image/png;base64,abc123'},
],
)
]
result = requester._convert_messages(messages)
assert len(result) == 1
content = result[0]['content']
assert isinstance(content, list)
# Check image_base64 converted to image_url
image_part = [p for p in content if p.get('type') == 'image_url'][0]
assert 'image_url' in image_part
assert image_part['image_url']['url'] == 'data:image/png;base64,abc123'
def test_convert_message_with_multiple_text_parts(self):
"""Test converting message with multiple text parts (LiteLLM handles this)"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
import langbot_plugin.api.entities.builtin.provider.message as provider_message
messages = [
provider_message.Message(
role='user',
content=[
{'type': 'text', 'text': 'Hello'},
{'type': 'text', 'text': 'World'},
],
)
]
result = requester._convert_messages(messages)
assert len(result) == 1
# LiteLLM handles multiple text parts, we pass them through
assert isinstance(result[0]['content'], list)
class TestScanModels:
"""Test scan_models method"""
@pytest.mark.asyncio
async def test_scan_models_basic(self):
"""Test basic model scanning"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.openai.com/v1',
'timeout': 60,
},
)
# Mock httpx response
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{'id': 'gpt-4o'},
{'id': 'text-embedding-3-small'},
{'id': 'gpt-3.5-turbo'},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
assert 'models' in result
assert len(result['models']) == 3
# Check LLM models are first
assert result['models'][0]['type'] == 'llm'
# Check embedding model is detected
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
assert len(embedding_models) == 1
@pytest.mark.asyncio
async def test_scan_models_no_base_url(self):
"""Test scan_models without base_url raises error"""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': '',
},
)
with pytest.raises(errors.RequesterError) as exc_info:
await requester.scan_models()
assert 'Base URL required' in str(exc_info.value)
if __name__ == '__main__':
pytest.main([__file__, '-v'])