fix(provider): preserve litellm usage details (#2246)

This commit is contained in:
huanghuoguoguo
2026-06-14 11:12:29 +08:00
committed by GitHub
parent 1ef4507d9a
commit 27be09ab15
3 changed files with 117 additions and 27 deletions
+32 -7
View File
@@ -115,6 +115,15 @@ class TestExtractUsage:
assert result['prompt_tokens'] == 0
assert result['completion_tokens'] == 0
def test_extract_usage_without_provider_usage(self):
"""Missing provider usage is not treated as authoritative zero usage."""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
response = Mock()
response.usage = None
assert requester._extract_usage(response) is None
class TestNormalizeUsage:
"""Test _normalize_usage helper covering real-world usage shapes"""
@@ -131,6 +140,22 @@ class TestNormalizeUsage:
)
assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20}
def test_preserves_token_details(self):
"""Provider token details such as cache counters are preserved."""
result = litellmchat.LiteLLMRequester._normalize_usage(
{
'prompt_tokens': 12,
'completion_tokens': 8,
'total_tokens': 20,
'prompt_tokens_details': {'cached_tokens': 7},
'completion_tokens_details': {'reasoning_tokens': 3},
}
)
assert result['prompt_tokens'] == 12
assert result['prompt_tokens_details'] == {'cached_tokens': 7}
assert result['completion_tokens_details'] == {'reasoning_tokens': 3}
def test_missing_total_is_derived(self):
"""When total_tokens is absent/zero it is derived from prompt + completion"""
usage = Mock()
@@ -166,9 +191,7 @@ class TestInvokeLLMStreamUsage:
if has_choice:
choice = Mock()
delta = Mock()
delta.model_dump = Mock(
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
)
delta.model_dump = Mock(return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls})
choice.delta = delta
choice.finish_reason = finish_reason
chunk.choices = [choice]
@@ -313,7 +336,8 @@ class TestInvokeLLMStreamUsage:
with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
collected = [
chunk async for chunk in requester.invoke_llm_stream(
chunk
async for chunk in requester.invoke_llm_stream(
query=query,
model=model,
messages=messages,
@@ -788,7 +812,9 @@ class TestInvokeRerank:
with patch('httpx.AsyncClient', return_value=mock_client):
# arerank must NOT be called on the openai-compatible path
with patch.object(
litellmchat, 'arerank', new_callable=AsyncMock,
litellmchat,
'arerank',
new_callable=AsyncMock,
side_effect=AssertionError('arerank must not be used for openai-compatible provider'),
):
results = await requester.invoke_rerank(
@@ -1068,8 +1094,7 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
mock_supports_function_calling.side_effect = (
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
and custom_llm_provider is None
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' and custom_llm_provider is None
)
assert requester._supports_function_calling('kimi-k2.6') is True