From 27be09ab15287e31f1aab3ecf81bf6b80ac7e359 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Sun, 14 Jun 2026 11:12:29 +0800 Subject: [PATCH] fix(provider): preserve litellm usage details (#2246) --- .../pkg/provider/modelmgr/requester.py | 21 ++++- .../modelmgr/requesters/litellmchat.py | 84 +++++++++++++++---- tests/unit_tests/provider/test_litellmchat.py | 39 +++++++-- 3 files changed, 117 insertions(+), 27 deletions(-) diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index b673c758..377f7d4a 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message +LLM_USAGE_QUERY_VARIABLE = '_llm_usage' +STREAM_USAGE_QUERY_VARIABLE = '_stream_usage' + + +def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None: + """Store the latest provider usage on the query for upstream action handlers.""" + if query is None or not usage_info: + return + if query.variables is None: + query.variables = {} + query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info) + + class RuntimeProvider: """运行时模型提供商""" @@ -67,6 +80,7 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: + _store_llm_usage(query, usage_info) input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) return msg @@ -146,11 +160,12 @@ class RuntimeProvider: if query: if query.variables is None: query.variables = {} - if '_stream_usage' in query.variables: - usage_info = query.variables['_stream_usage'] + if STREAM_USAGE_QUERY_VARIABLE in query.variables: + usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE] + _store_llm_usage(query, usage_info) input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) - del query.variables['_stream_usage'] + del query.variables[STREAM_USAGE_QUERY_VARIABLE] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 63536a2c..8c750bd7 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -262,32 +262,82 @@ class LiteLLMRequester(requester.ProviderAPIRequester): - dict with the same keys - missing ``total_tokens`` (derived from prompt + completion) - ``None`` / partially-populated usage (defaults to 0) + - provider-specific token details, including cache token counters """ - if usage is None: - return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0} - def _get(key: str) -> typing.Any: - if isinstance(usage, dict): - return usage.get(key) - return getattr(usage, key, None) + def _plain_value(value: typing.Any) -> typing.Any: + if value is None: + return None + if isinstance(value, dict): + return {k: _plain_value(v) for k, v in value.items() if v is not None} + if isinstance(value, (list, tuple)): + return [_plain_value(v) for v in value] - prompt_tokens = _get('prompt_tokens') or 0 - completion_tokens = _get('completion_tokens') or 0 - total_tokens = _get('total_tokens') or 0 + model_dump = getattr(value, 'model_dump', None) + if callable(model_dump): + try: + dumped = model_dump() + if isinstance(dumped, dict): + return _plain_value(dumped) + except Exception: + pass + + return value + + def _usage_dict(value: typing.Any) -> dict[str, typing.Any]: + if value is None: + return {} + plain = _plain_value(value) + if isinstance(plain, dict): + return plain + + def _is_mock_attr(attr: typing.Any) -> bool: + return type(attr).__module__.startswith('unittest.mock') + + data: dict[str, typing.Any] = {} + for key in ( + 'prompt_tokens', + 'completion_tokens', + 'total_tokens', + 'prompt_tokens_details', + 'completion_tokens_details', + 'cache_creation_input_tokens', + 'cache_read_input_tokens', + 'input_token_details', + 'output_token_details', + ): + attr_value = getattr(value, key, None) + if attr_value is not None and not _is_mock_attr(attr_value): + data[key] = _plain_value(attr_value) + return data + + def _to_int(value: typing.Any) -> int: + try: + return int(value or 0) + except (TypeError, ValueError): + return 0 + + normalized = _usage_dict(usage) + + prompt_tokens = _to_int(normalized.get('prompt_tokens')) + completion_tokens = _to_int(normalized.get('completion_tokens')) + total_tokens = _to_int(normalized.get('total_tokens')) # Some providers omit total_tokens in streaming usage; derive it. if not total_tokens: total_tokens = prompt_tokens + completion_tokens - return { - 'prompt_tokens': int(prompt_tokens), - 'completion_tokens': int(completion_tokens), - 'total_tokens': int(total_tokens), - } + normalized['prompt_tokens'] = prompt_tokens + normalized['completion_tokens'] = completion_tokens + normalized['total_tokens'] = total_tokens + return normalized - def _extract_usage(self, response) -> dict: + def _extract_usage(self, response) -> dict | None: """Extract usage info from a non-streaming LiteLLM response.""" - return self._normalize_usage(getattr(response, 'usage', None)) + usage = getattr(response, 'usage', None) + if usage is None: + return None + return self._normalize_usage(usage) @staticmethod def _as_dict(value: typing.Any) -> dict: @@ -486,7 +536,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): if query is not None: if query.variables is None: query.variables = {} - query.variables['_stream_usage'] = usage_info + query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info if not hasattr(chunk, 'choices') or not chunk.choices: continue diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index 2878d827..abe0cf49 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -115,6 +115,15 @@ class TestExtractUsage: assert result['prompt_tokens'] == 0 assert result['completion_tokens'] == 0 + def test_extract_usage_without_provider_usage(self): + """Missing provider usage is not treated as authoritative zero usage.""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + response = Mock() + response.usage = None + + assert requester._extract_usage(response) is None + class TestNormalizeUsage: """Test _normalize_usage helper covering real-world usage shapes""" @@ -131,6 +140,22 @@ class TestNormalizeUsage: ) assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20} + def test_preserves_token_details(self): + """Provider token details such as cache counters are preserved.""" + result = litellmchat.LiteLLMRequester._normalize_usage( + { + 'prompt_tokens': 12, + 'completion_tokens': 8, + 'total_tokens': 20, + 'prompt_tokens_details': {'cached_tokens': 7}, + 'completion_tokens_details': {'reasoning_tokens': 3}, + } + ) + + assert result['prompt_tokens'] == 12 + assert result['prompt_tokens_details'] == {'cached_tokens': 7} + assert result['completion_tokens_details'] == {'reasoning_tokens': 3} + def test_missing_total_is_derived(self): """When total_tokens is absent/zero it is derived from prompt + completion""" usage = Mock() @@ -166,9 +191,7 @@ class TestInvokeLLMStreamUsage: if has_choice: choice = Mock() delta = Mock() - delta.model_dump = Mock( - return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls} - ) + delta.model_dump = Mock(return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}) choice.delta = delta choice.finish_reason = finish_reason chunk.choices = [choice] @@ -313,7 +336,8 @@ class TestInvokeLLMStreamUsage: with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())): collected = [ - chunk async for chunk in requester.invoke_llm_stream( + chunk + async for chunk in requester.invoke_llm_stream( query=query, model=model, messages=messages, @@ -788,7 +812,9 @@ class TestInvokeRerank: with patch('httpx.AsyncClient', return_value=mock_client): # arerank must NOT be called on the openai-compatible path with patch.object( - litellmchat, 'arerank', new_callable=AsyncMock, + litellmchat, + 'arerank', + new_callable=AsyncMock, side_effect=AssertionError('arerank must not be used for openai-compatible provider'), ): results = await requester.invoke_rerank( @@ -1068,8 +1094,7 @@ class TestScanModels: with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling: mock_supports_function_calling.side_effect = ( - lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' - and custom_llm_provider is None + lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' and custom_llm_provider is None ) assert requester._supports_function_calling('kimi-k2.6') is True