fix(provider): use LiteLLM input window for context length (#2243)

This commit is contained in:
huanghuoguoguo
2026-06-13 21:27:47 +08:00
committed by GitHub
parent b6fde30aa7
commit 7fe3eedeea
2 changed files with 45 additions and 16 deletions

View File

@@ -75,22 +75,33 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
continue
return False
@staticmethod
def _positive_int(value: typing.Any) -> int | None:
if isinstance(value, bool):
return None
if isinstance(value, int) and value > 0:
return value
if isinstance(value, str) and value.isdigit():
parsed_value = int(value)
if parsed_value > 0:
return parsed_value
return None
def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
if not model_payload:
return None
for field_name in ('context_length', 'context_window', 'max_context_length'):
value = model_payload.get(field_name)
if isinstance(value, bool):
continue
if isinstance(value, int) and value > 0:
return value
if isinstance(value, str) and value.isdigit():
parsed_value = int(value)
if parsed_value > 0:
return parsed_value
context_length = self._positive_int(model_payload.get(field_name))
if context_length is not None:
return context_length
return None
def _context_length_from_litellm_model_info(self, model_info: typing.Any) -> int | None:
if isinstance(model_info, dict):
return self._positive_int(model_info.get('max_input_tokens'))
return self._positive_int(getattr(model_info, 'max_input_tokens', None))
def _metadata_provider_candidates(self, model_name: str) -> list[str]:
normalized_model_name = (model_name or '').lower()
candidates = []
@@ -126,7 +137,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
return None
def _safe_context_length(self, model_name: str) -> int | None:
helper = getattr(litellm, 'get_max_tokens', None)
helper = getattr(litellm, 'get_model_info', None)
if not callable(helper):
return self._known_context_length_fallback(model_name)
@@ -143,11 +154,12 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
continue
tried_candidates.append(candidate)
try:
max_tokens = helper(candidate)
model_info = helper(candidate)
except Exception:
continue
if isinstance(max_tokens, int) and max_tokens > 0:
return max_tokens
context_length = self._context_length_from_litellm_model_info(model_info)
if context_length is not None:
return context_length
return self._known_context_length_fallback(model_name)
def _supports_function_calling(self, model_name: str) -> bool:

View File

@@ -1034,11 +1034,28 @@ class TestScanModels:
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens:
mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None
with patch.object(litellmchat.litellm, 'get_model_info') as mock_get_model_info:
mock_get_model_info.side_effect = (
lambda model: {'max_input_tokens': 131072}
if model == 'moonshot/moonshot-v1-128k'
else {}
)
assert requester._safe_context_length('moonshot-v1-128k') == 131072
def test_safe_context_length_uses_litellm_max_input_tokens(self):
    """The context window must come from ``max_input_tokens``, not the output-token limits."""
    # Output-side limits deliberately differ from the input window so a
    # mix-up between the fields would change the asserted value.
    model_info = {
        'max_input_tokens': 128000,
        'max_output_tokens': 16384,
        'max_tokens': 16384,
    }
    requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})

    with patch.object(litellmchat.litellm, 'get_model_info', return_value=model_info):
        context_length = requester._safe_context_length('gpt-4o')

    assert context_length == 128000
def test_litellm_bool_helper_tries_moonshot_metadata_alias(self):
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities."""
requester = litellmchat.LiteLLMRequester(
@@ -1102,7 +1119,7 @@ class TestScanModels:
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')):
with patch.object(litellmchat.litellm, 'get_model_info', side_effect=Exception('not mapped')):
assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000
assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000