mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 15:56:03 +00:00
feat(models): persist context metadata
This commit is contained in:
@@ -35,6 +35,7 @@ def _create_mock_llm_model(
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
abilities: list = None,
|
||||
context_length: int | None = None,
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
@@ -43,6 +44,7 @@ def _create_mock_llm_model(
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.abilities = abilities or []
|
||||
model.context_length = context_length
|
||||
model.extra_args = extra_args or {}
|
||||
return model
|
||||
|
||||
@@ -142,10 +144,12 @@ class TestRuntimeModelData:
|
||||
'name': 'Model',
|
||||
'provider_uuid': 'provider',
|
||||
'abilities': ['vision'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temp': 0.7},
|
||||
}
|
||||
result = _runtime_model_data('uuid', update_payload)
|
||||
assert result['abilities'] == ['vision']
|
||||
assert result['context_length'] == 128000
|
||||
assert result['extra_args'] == {'temp': 0.7}
|
||||
|
||||
|
||||
@@ -188,7 +192,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
model = _create_mock_llm_model(context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
@@ -206,6 +210,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
@@ -218,6 +223,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == 'Test LLM'
|
||||
assert result[0]['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_models_hide_secret_keys(self):
|
||||
"""Hides secret API keys when include_secret=False."""
|
||||
@@ -265,7 +271,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid')
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
@@ -279,11 +285,12 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Test LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': getattr(entity, 'api_keys', None),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -295,6 +302,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
assert result['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
@@ -402,6 +410,39 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
|
||||
async def test_create_llm_model_persists_context_length_as_column(self):
    """Verify create_llm_model keeps context_length separate from extra_args.

    The entity passed to ``model_mgr.load_llm_model_with_provider`` must
    expose ``context_length`` as its own attribute, and its ``extra_args``
    must contain exactly the options supplied in the payload.
    """
    # Application stub assembled in one expression; it carries exactly the
    # attributes this test configures.
    ap = SimpleNamespace(
        persistence_mgr=SimpleNamespace(
            execute_async=AsyncMock(return_value=_create_mock_result([])),
        ),
        model_mgr=SimpleNamespace(
            provider_dict={'provider-uuid': Mock()},
            llm_models=[],
            load_llm_model_with_provider=AsyncMock(return_value=Mock()),
        ),
        pipeline_service=SimpleNamespace(update_pipeline=AsyncMock()),
    )

    payload = {
        'uuid': 'model-with-context',
        'name': 'Context Model',
        'provider_uuid': 'provider-uuid',
        'abilities': ['func_call'],
        'context_length': 128000,
        'extra_args': {'temperature': 0.2},
    }

    await LLMModelsService(ap).create_llm_model(
        payload,
        preserve_uuid=True,
        auto_set_to_default_pipeline=False,
    )

    # Inspect the first positional argument of the awaited loader call.
    load_call = ap.model_mgr.load_llm_model_with_provider.await_args
    runtime_entity = load_call.args[0]
    assert runtime_entity.context_length == 128000
    assert runtime_entity.extra_args == {'temperature': 0.2}
    assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
@@ -512,6 +553,35 @@ class TestLLMModelsServiceUpdateLLMModel:
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
|
||||
async def test_update_llm_model_reloads_context_length_as_column(self):
    """Verify update_llm_model reloads the runtime entity with context_length.

    After the update, the entity passed to
    ``model_mgr.load_llm_model_with_provider`` must keep the UUID given to
    ``update_llm_model``, expose ``context_length`` as its own attribute,
    and leave ``extra_args`` equal to the options in the payload.
    """
    # Application stub assembled in one expression; it carries exactly the
    # attributes this test configures.
    ap = SimpleNamespace(
        persistence_mgr=SimpleNamespace(execute_async=AsyncMock()),
        model_mgr=SimpleNamespace(
            provider_dict={'provider-uuid': Mock()},
            llm_models=[],
            remove_llm_model=AsyncMock(),
            load_llm_model_with_provider=AsyncMock(return_value=Mock()),
        ),
    )

    changes = {
        'name': 'Updated Name',
        'provider_uuid': 'provider-uuid',
        'abilities': ['vision'],
        'context_length': 64000,
        'extra_args': {'temperature': 0.4},
    }

    await LLMModelsService(ap).update_llm_model('existing-uuid', changes)

    # Inspect the first positional argument of the awaited loader call.
    load_call = ap.model_mgr.load_llm_model_with_provider.await_args
    runtime_entity = load_call.args[0]
    assert runtime_entity.uuid == 'existing-uuid'
    assert runtime_entity.context_length == 64000
    assert runtime_entity.extra_args == {'temperature': 0.4}
    assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
|
||||
class TestLLMModelsServiceDeleteLLMModel:
|
||||
"""Tests for LLMModelsService.delete_llm_model method."""
|
||||
@@ -961,4 +1031,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
result = await service.get_rerank_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert len(result) == 2
|
||||
|
||||
@@ -896,6 +896,121 @@ class TestScanModels:
|
||||
assert by_id['text-embedding-3-small']['type'] == 'embedding'
|
||||
assert by_id['bge-reranker-v2']['type'] == 'rerank'
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_models_prefers_context_length_from_provider_payload(self):
    """Verify scan_models reports the context_length the endpoint returns.

    The stubbed model listing already contains ``context_length``, so the
    scan result must echo it and ``_safe_context_length`` must never be
    called.
    """
    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={'base_url': 'https://api.moonshot.cn/v1', 'timeout': 60},
    )
    # Stub the metadata helpers so only the HTTP payload can supply values.
    requester._safe_context_length = Mock(return_value=None)
    requester._supports_vision = Mock(return_value=False)
    requester._supports_function_calling = Mock(return_value=False)

    listing = {
        'data': [
            {'id': 'moonshot-v1-128k', 'context_length': 131072},
        ]
    }
    mock_response = Mock()
    mock_response.raise_for_status = Mock()
    mock_response.json = Mock(return_value=listing)

    with patch('httpx.AsyncClient') as mock_client:
        # Object yielded by ``async with httpx.AsyncClient(...)``; its
        # ``get`` resolves to the canned response above.
        session = Mock()
        session.get = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__ = AsyncMock(return_value=session)

        result = await requester.scan_models(api_key='test-key')

    first_model = result['models'][0]
    assert first_model['context_length'] == 131072
    requester._safe_context_length.assert_not_called()
|
||||
|
||||
def test_safe_context_length_tries_moonshot_metadata_alias(self):
    """Verify _safe_context_length resolves a Moonshot id via its alias.

    The requester is configured as an ``openai`` provider pointed at the
    Moonshot base URL. The patched ``litellm.get_max_tokens`` returns a
    value only for the ``moonshot/``-prefixed id, and the helper must
    still return that value for the bare id.
    """

    def fake_get_max_tokens(model):
        # Only the provider-prefixed alias is known to the fake metadata.
        if model == 'moonshot/moonshot-v1-128k':
            return 131072
        return None

    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={
            'base_url': 'https://api.moonshot.cn/v1',
            'custom_llm_provider': 'openai',
        },
    )

    with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=fake_get_max_tokens):
        resolved = requester._safe_context_length('moonshot-v1-128k')
        assert resolved == 131072
|
||||
|
||||
def test_litellm_bool_helper_tries_moonshot_metadata_alias(self):
    """Verify _supports_function_calling resolves a Moonshot id via its alias.

    The requester is configured as an ``openai`` provider pointed at the
    Moonshot base URL. The patched ``litellm.supports_function_calling``
    is truthy only for the ``moonshot/``-prefixed id called without a
    ``custom_llm_provider``, and the helper must still return ``True``
    for the bare id.
    """

    def fake_supports_function_calling(model, custom_llm_provider=None):
        # True only for the prefixed alias with no explicit provider.
        return model == 'moonshot/kimi-k2.6' and custom_llm_provider is None

    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={
            'base_url': 'https://api.moonshot.cn/v1',
            'custom_llm_provider': 'openai',
        },
    )

    with patch.object(
        litellmchat.litellm,
        'supports_function_calling',
        side_effect=fake_supports_function_calling,
    ):
        outcome = requester._supports_function_calling('kimi-k2.6')
        assert outcome is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_models_uses_provider_payload_for_vision_ability(self):
    """Verify scan_models derives the vision ability from the endpoint payload.

    All metadata helpers are stubbed to report no support, and the model
    listing sets ``supports_image_in``. The scanned model's abilities must
    therefore be exactly ``['vision']``.
    """
    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={'base_url': 'https://api.moonshot.cn/v1', 'timeout': 60},
    )
    # Stub the metadata helpers so only the HTTP payload can supply values.
    requester._safe_context_length = Mock(return_value=None)
    requester._supports_vision = Mock(return_value=False)
    requester._supports_function_calling = Mock(return_value=False)

    listing = {
        'data': [
            {
                'id': 'moonshot-v1-128k-vision-preview',
                'supports_image_in': True,
            },
        ]
    }
    mock_response = Mock()
    mock_response.raise_for_status = Mock()
    mock_response.json = Mock(return_value=listing)

    with patch('httpx.AsyncClient') as mock_client:
        # Object yielded by ``async with httpx.AsyncClient(...)``; its
        # ``get`` resolves to the canned response above.
        session = Mock()
        session.get = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__ = AsyncMock(return_value=session)

        result = await requester.scan_models(api_key='test-key')

    first_model = result['models'][0]
    assert first_model['abilities'] == ['vision']
|
||||
|
||||
def test_safe_context_length_falls_back_for_deepseek_v4_models(self):
    """Verify _safe_context_length still answers for DeepSeek V4 ids.

    ``litellm.get_max_tokens`` is patched to raise on every call. The
    helper must nevertheless return 1_000_000 for both V4 model ids.
    """
    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={
            'base_url': 'https://api.deepseek.com',
            'custom_llm_provider': 'deepseek',
        },
    )

    # Every metadata lookup fails, so any value returned below cannot
    # have come from the patched litellm function.
    failing_lookup = patch.object(
        litellmchat.litellm,
        'get_max_tokens',
        side_effect=Exception('not mapped'),
    )
    with failing_lookup:
        for model_id in ('deepseek-v4-pro', 'deepseek-v4-flash'):
            assert requester._safe_context_length(model_id) == 1_000_000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_no_base_url(self):
|
||||
"""Test scan_models without base_url raises error"""
|
||||
|
||||
@@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
'api_keys': ['temp-key'],
|
||||
},
|
||||
'abilities': ['func_call'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temperature': 0.5},
|
||||
}
|
||||
|
||||
@@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
|
||||
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
|
||||
assert runtime_model.model_entity.name == 'TempModel'
|
||||
assert runtime_model.model_entity.context_length == 128000
|
||||
assert runtime_model.model_entity.extra_args == {'temperature': 0.5}
|
||||
assert 'context_length' not in runtime_model.model_entity.extra_args
|
||||
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
|
||||
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
|
||||
|
||||
@@ -785,4 +789,4 @@ def test_provider_not_found_error_str():
|
||||
error = provider_errors.ProviderNotFoundError('test-provider')
|
||||
|
||||
assert str(error) == 'Provider test-provider not found'
|
||||
assert error.provider_name == 'test-provider'
|
||||
assert error.provider_name == 'test-provider'
|
||||
|
||||
Reference in New Issue
Block a user