feat(models): persist context metadata

This commit is contained in:
huanghuoguoguo
2026-06-08 00:39:30 +08:00
parent 573e1fe36e
commit b82db2b7f8
23 changed files with 498 additions and 22 deletions
@@ -896,6 +896,121 @@ class TestScanModels:
assert by_id['text-embedding-3-small']['type'] == 'embedding'
assert by_id['bge-reranker-v2']['type'] == 'rerank'
@pytest.mark.asyncio
async def test_scan_models_prefers_context_length_from_provider_payload(self):
"""Provider-supplied context_length is preserved before LiteLLM metadata fallback."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 60,
},
)
requester._supports_function_calling = Mock(return_value=False)
requester._supports_vision = Mock(return_value=False)
requester._safe_context_length = Mock(return_value=None)
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{'id': 'moonshot-v1-128k', 'context_length': 131072},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
assert result['models'][0]['context_length'] == 131072
requester._safe_context_length.assert_not_called()
def test_safe_context_length_tries_moonshot_metadata_alias(self):
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for context windows."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'custom_llm_provider': 'openai',
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens:
mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None
assert requester._safe_context_length('moonshot-v1-128k') == 131072
def test_litellm_bool_helper_tries_moonshot_metadata_alias(self):
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'custom_llm_provider': 'openai',
},
)
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
mock_supports_function_calling.side_effect = (
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
and custom_llm_provider is None
)
assert requester._supports_function_calling('kimi-k2.6') is True
@pytest.mark.asyncio
async def test_scan_models_uses_provider_payload_for_vision_ability(self):
"""Provider-supplied vision support is used when scanning models."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 60,
},
)
requester._supports_function_calling = Mock(return_value=False)
requester._supports_vision = Mock(return_value=False)
requester._safe_context_length = Mock(return_value=None)
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{
'id': 'moonshot-v1-128k-vision-preview',
'supports_image_in': True,
},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
assert result['models'][0]['abilities'] == ['vision']
def test_safe_context_length_falls_back_for_deepseek_v4_models(self):
"""DeepSeek V4 API ids have a known 1M context even before LiteLLM maps them."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.deepseek.com',
'custom_llm_provider': 'deepseek',
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')):
assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000
assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000
@pytest.mark.asyncio
async def test_scan_models_no_base_url(self):
"""Test scan_models without base_url raises error"""
@@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
'api_keys': ['temp-key'],
},
'abilities': ['func_call'],
'context_length': 128000,
'extra_args': {'temperature': 0.5},
}
@@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
assert runtime_model.model_entity.name == 'TempModel'
assert runtime_model.model_entity.context_length == 128000
assert runtime_model.model_entity.extra_args == {'temperature': 0.5}
assert 'context_length' not in runtime_model.model_entity.extra_args
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
@@ -785,4 +789,4 @@ def test_provider_not_found_error_str():
error = provider_errors.ProviderNotFoundError('test-provider')
assert str(error) == 'Provider test-provider not found'
assert error.provider_name == 'test-provider'
assert error.provider_name == 'test-provider'