diff --git a/src/langbot/pkg/api/http/service/apikey.py b/src/langbot/pkg/api/http/service/apikey.py index c46b5608..5e6ff15d 100644 --- a/src/langbot/pkg/api/http/service/apikey.py +++ b/src/langbot/pkg/api/http/service/apikey.py @@ -52,6 +52,9 @@ class ApiKeyService: async def verify_api_key(self, key: str) -> bool: """Verify if an API key is valid""" + if not isinstance(key, str) or not key.startswith('lbk_'): + return False + result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key) ) diff --git a/tests/unit_tests/api/test_apikey_service.py b/tests/unit_tests/api/test_apikey_service.py new file mode 100644 index 00000000..67b6737b --- /dev/null +++ b/tests/unit_tests/api/test_apikey_service.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +from langbot.pkg.api.http.service.apikey import ApiKeyService + + +@pytest.mark.asyncio +@pytest.mark.parametrize('api_key', [None, 123, b'lbk_bytes', '', 'plain_key', ' LBK_bad', 'sk-lbk_fake']) +async def test_verify_api_key_rejects_non_lbk_keys_without_db_query(api_key): + persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) + service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr)) + + result = await service.verify_api_key(api_key) + + assert result is False + persistence_mgr.execute_async.assert_not_awaited() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('db_row', 'expected'), + [ + (object(), True), + (None, False), + ], +) +async def test_verify_api_key_keeps_db_validation_for_lbk_keys(db_row, expected): + query_result = Mock() + query_result.first.return_value = db_row + persistence_mgr = SimpleNamespace(execute_async=AsyncMock(return_value=query_result)) + service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr)) + + result = await service.verify_api_key('lbk_valid_format') + + assert result is expected + persistence_mgr.execute_async.assert_awaited_once()