diff --git a/src/langbot/pkg/utils/runner.py b/src/langbot/pkg/utils/runner.py index 43aecc06..03628db1 100644 --- a/src/langbot/pkg/utils/runner.py +++ b/src/langbot/pkg/utils/runner.py @@ -1,5 +1,7 @@ from __future__ import annotations +import ipaddress +import re from urllib.parse import urlparse @@ -44,6 +46,40 @@ LOCAL_PATTERNS = [ '172.31.', ] +HOST_LABEL_PATTERN = re.compile(r'^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$') + + +def _is_valid_hostname(host: str) -> bool: + if host == 'localhost': + return True + + try: + ipaddress.ip_address(host) + return True + except ValueError: + pass + + if not host or len(host) > 253 or any(char.isspace() for char in host): + return False + + host = host.rstrip('.') + if not host: + return False + + return all(HOST_LABEL_PATTERN.match(label) for label in host.split('.')) + + +def _is_local_host(host: str) -> bool: + if host == 'localhost': + return True + + try: + ip_address = ipaddress.ip_address(host) + except ValueError: + return False + + return ip_address.is_private or ip_address.is_loopback or ip_address.is_unspecified + def get_runner_category(runner_name: str, runner_url: str) -> str: if not runner_url: @@ -52,12 +88,15 @@ def get_runner_category(runner_name: str, runner_url: str) -> str: try: parsed_url = urlparse(runner_url) host = parsed_url.hostname.lower() if parsed_url.hostname else '' - except Exception: + _ = parsed_url.port + except (TypeError, ValueError): return RunnerCategory.UNKNOWN - for pattern in LOCAL_PATTERNS: - if host.startswith(pattern): - return RunnerCategory.LOCAL + if not parsed_url.scheme or not host or not _is_valid_hostname(host): + return RunnerCategory.UNKNOWN + + if _is_local_host(host): + return RunnerCategory.LOCAL for domain in CLOUD_DOMAINS: if host.endswith(domain): diff --git a/tests/unit_tests/utils/test_runner.py b/tests/unit_tests/utils/test_runner.py index 155f6e77..c71dc793 100644 --- a/tests/unit_tests/utils/test_runner.py +++ b/tests/unit_tests/utils/test_runner.py @@ -99,6 +99,49 @@ class TestGetRunnerCategory: result = runner.get_runner_category("test", "http://example.com") assert result == RunnerCategory.UNKNOWN + @pytest.mark.parametrize( + "runner_url", + [ + "api.dify.ai/v1", + "localhost:7860", + "https:///v1", + "https://", + "https://exa mple.com", + "http://[::1", + "http://localhost:bad", + ], + ) + def test_invalid_urls_return_unknown(self, runner_url): + """Invalid or scheme-less URLs should not default to CLOUD.""" + assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN + + @pytest.mark.parametrize( + "runner_url", + [ + "http://localhost:7860", + "http://127.0.0.1:7860", + "http://10.0.0.1:7860", + "http://172.16.0.1:7860", + "http://172.31.255.255:7860", + "http://192.168.1.20:7860", + "http://[::1]:7860", + ], + ) + def test_local_hosts_are_detected_with_ipaddress(self, runner_url): + """Loopback/private IP addresses and localhost should be LOCAL.""" + assert get_runner_category("test", runner_url) == RunnerCategory.LOCAL + + @pytest.mark.parametrize( + "runner_url", + [ + "http://10.evil.com", + "http://192.168.example.com", + ], + ) + def test_private_ip_prefix_domains_are_not_local(self, runner_url): + """Domain names that only look like private IP prefixes should not be LOCAL.""" + assert get_runner_category("test", runner_url) == RunnerCategory.CLOUD + class TestIsCloudRunner: """Test is_cloud_runner helper function."""