feat(models): persist context metadata

This commit is contained in:
huanghuoguoguo
2026-06-08 00:39:30 +08:00
parent 573e1fe36e
commit b82db2b7f8
23 changed files with 498 additions and 22 deletions

View File

@@ -35,6 +35,7 @@ def _create_mock_llm_model(
name: str = 'Test LLM',
provider_uuid: str = 'provider-uuid',
abilities: list = None,
context_length: int | None = None,
extra_args: dict = None,
) -> Mock:
"""Helper to create mock LLMModel entity."""
@@ -43,6 +44,7 @@ def _create_mock_llm_model(
model.name = name
model.provider_uuid = provider_uuid
model.abilities = abilities or []
model.context_length = context_length
model.extra_args = extra_args or {}
return model
@@ -142,10 +144,12 @@ class TestRuntimeModelData:
'name': 'Model',
'provider_uuid': 'provider',
'abilities': ['vision'],
'context_length': 128000,
'extra_args': {'temp': 0.7},
}
result = _runtime_model_data('uuid', update_payload)
assert result['abilities'] == ['vision']
assert result['context_length'] == 128000
assert result['extra_args'] == {'temp': 0.7}
@@ -188,7 +192,7 @@ class TestLLMModelsServiceGetLLMModels:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model()
model = _create_mock_llm_model(context_length=128000)
provider = _create_mock_provider()
mock_model_result = _create_mock_result([model])
@@ -206,6 +210,7 @@ class TestLLMModelsServiceGetLLMModels:
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
'context_length': getattr(entity, 'context_length', None),
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
}
)
@@ -218,6 +223,7 @@ class TestLLMModelsServiceGetLLMModels:
# Verify
assert len(result) == 1
assert result[0]['name'] == 'Test LLM'
assert result[0]['context_length'] == 128000
async def test_get_llm_models_hide_secret_keys(self):
"""Hides secret API keys when include_secret=False."""
@@ -265,7 +271,7 @@ class TestLLMModelsServiceGetLLMModel:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model(model_uuid='found-uuid')
model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000)
provider = _create_mock_provider()
mock_model_result = _create_mock_result([], first_item=model)
@@ -279,11 +285,12 @@ class TestLLMModelsServiceGetLLMModel:
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-uuid',
'name': 'Test LLM',
'provider_uuid': 'provider-uuid',
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': getattr(entity, 'provider_uuid', None),
'context_length': getattr(entity, 'context_length', None),
'api_keys': getattr(entity, 'api_keys', None),
}
)
@@ -295,6 +302,7 @@ class TestLLMModelsServiceGetLLMModel:
# Verify
assert result is not None
assert result['uuid'] == 'found-uuid'
assert result['context_length'] == 128000
async def test_get_llm_model_not_found(self):
"""Returns None when model not found."""
@@ -402,6 +410,39 @@ class TestLLMModelsServiceCreateLLMModel:
# Verify
assert model_uuid == 'preserved-uuid'
async def test_create_llm_model_persists_context_length_as_column(self):
"""Creates LLM model with context_length outside extra_args."""
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock())
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
await service.create_llm_model(
{
'uuid': 'model-with-context',
'name': 'Context Model',
'provider_uuid': 'provider-uuid',
'abilities': ['func_call'],
'context_length': 128000,
'extra_args': {'temperature': 0.2},
},
preserve_uuid=True,
auto_set_to_default_pipeline=False,
)
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
assert runtime_entity.context_length == 128000
assert runtime_entity.extra_args == {'temperature': 0.2}
assert 'context_length' not in runtime_entity.extra_args
async def test_create_llm_model_provider_not_found_raises_error(self):
"""Raises Exception when provider not found in runtime."""
# Setup
@@ -512,6 +553,35 @@ class TestLLMModelsServiceUpdateLLMModel:
'provider_uuid': 'nonexistent-provider',
})
async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args."""
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.remove_llm_model = AsyncMock()
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
service = LLMModelsService(ap)
await service.update_llm_model(
'existing-uuid',
{
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
'abilities': ['vision'],
'context_length': 64000,
'extra_args': {'temperature': 0.4},
},
)
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
assert runtime_entity.uuid == 'existing-uuid'
assert runtime_entity.context_length == 64000
assert runtime_entity.extra_args == {'temperature': 0.4}
assert 'context_length' not in runtime_entity.extra_args
class TestLLMModelsServiceDeleteLLMModel:
"""Tests for LLMModelsService.delete_llm_model method."""
@@ -961,4 +1031,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
result = await service.get_rerank_models_by_provider('provider-uuid')
# Verify
assert len(result) == 2
assert len(result) == 2