diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 6b087916..63536a2c 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -75,22 +75,33 @@ class LiteLLMRequester(requester.ProviderAPIRequester): continue return False + @staticmethod + def _positive_int(value: typing.Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int) and value > 0: + return value + if isinstance(value, str) and value.isdigit(): + parsed_value = int(value) + if parsed_value > 0: + return parsed_value + return None + def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None: if not model_payload: return None for field_name in ('context_length', 'context_window', 'max_context_length'): - value = model_payload.get(field_name) - if isinstance(value, bool): - continue - if isinstance(value, int) and value > 0: - return value - if isinstance(value, str) and value.isdigit(): - parsed_value = int(value) - if parsed_value > 0: - return parsed_value + context_length = self._positive_int(model_payload.get(field_name)) + if context_length is not None: + return context_length return None + def _context_length_from_litellm_model_info(self, model_info: typing.Any) -> int | None: + if isinstance(model_info, dict): + return self._positive_int(model_info.get('max_input_tokens')) + return self._positive_int(getattr(model_info, 'max_input_tokens', None)) + def _metadata_provider_candidates(self, model_name: str) -> list[str]: normalized_model_name = (model_name or '').lower() candidates = [] @@ -126,7 +137,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): return None def _safe_context_length(self, model_name: str) -> int | None: - helper = getattr(litellm, 'get_max_tokens', None) + helper = getattr(litellm, 'get_model_info', None) if not callable(helper): return self._known_context_length_fallback(model_name) @@ -143,11 +154,12 @@ class LiteLLMRequester(requester.ProviderAPIRequester): continue tried_candidates.append(candidate) try: - max_tokens = helper(candidate) + model_info = helper(candidate) except Exception: continue - if isinstance(max_tokens, int) and max_tokens > 0: - return max_tokens + context_length = self._context_length_from_litellm_model_info(model_info) + if context_length is not None: + return context_length return self._known_context_length_fallback(model_name) def _supports_function_calling(self, model_name: str) -> bool: diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index 1ec12d82..2878d827 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -1034,11 +1034,28 @@ class TestScanModels: }, ) - with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens: - mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None + with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info: + mock_get_model_info.side_effect = ( + lambda model: {'max_input_tokens': 131072} + if model == 'moonshot/moonshot-v1-128k' + else {} + ) assert requester._safe_context_length('moonshot-v1-128k') == 131072 + def test_safe_context_length_uses_litellm_max_input_tokens(self): + """LiteLLM max_output_tokens must not be treated as the context window.""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info: + mock_get_model_info.return_value = { + 'max_input_tokens': 128000, + 'max_output_tokens': 16384, + 'max_tokens': 16384, + } + + assert requester._safe_context_length('gpt-4o') == 128000 + def test_litellm_bool_helper_tries_moonshot_metadata_alias(self): """OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities.""" requester = litellmchat.LiteLLMRequester( @@ -1102,7 +1119,7 @@ class TestScanModels: }, ) - with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')): + with patch.object(litellmchat.litellm, 'get_model_info', side_effect=Exception('not mapped')): assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000 assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000