From cff9ac568352d0753d7fa4a52c7006586e3bff7f Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <60681390+huanghuoguoguo@users.noreply.github.com> Date: Sun, 14 Jun 2026 11:04:34 +0800 Subject: [PATCH] chore(agent-runner): split litellm usage details --- .../pkg/provider/modelmgr/requester.py | 21 +---- .../modelmgr/requesters/litellmchat.py | 83 ++++--------------- tests/unit_tests/provider/test_litellmchat.py | 25 ------ 3 files changed, 20 insertions(+), 109 deletions(-) diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index 377f7d4a..b673c758 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -12,19 +12,6 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message -LLM_USAGE_QUERY_VARIABLE = '_llm_usage' -STREAM_USAGE_QUERY_VARIABLE = '_stream_usage' - - -def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None: - """Store the latest provider usage on the query for upstream action handlers.""" - if query is None or not usage_info: - return - if query.variables is None: - query.variables = {} - query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info) - - class RuntimeProvider: """运行时模型提供商""" @@ -80,7 +67,6 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: - _store_llm_usage(query, usage_info) input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) return msg @@ -160,12 +146,11 @@ class RuntimeProvider: if query: if query.variables is None: query.variables = {} - if STREAM_USAGE_QUERY_VARIABLE in query.variables: - usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE] - _store_llm_usage(query, usage_info) + if '_stream_usage' in query.variables: + usage_info = query.variables['_stream_usage'] input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) - del query.variables[STREAM_USAGE_QUERY_VARIABLE] + del query.variables['_stream_usage'] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 6e5d53e8..6b087916 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -250,81 +250,32 @@ class LiteLLMRequester(requester.ProviderAPIRequester): - dict with the same keys - missing ``total_tokens`` (derived from prompt + completion) - ``None`` / partially-populated usage (defaults to 0) - - provider-specific token details, including cache token counters """ - def _plain_value(value: typing.Any) -> typing.Any: - if value is None: - return None - if isinstance(value, dict): - return {k: _plain_value(v) for k, v in value.items() if v is not None} - if isinstance(value, (list, tuple)): - return [_plain_value(v) for v in value] + if usage is None: + return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0} - model_dump = getattr(value, 'model_dump', None) - if callable(model_dump): - try: - dumped = model_dump() - if isinstance(dumped, dict): - return _plain_value(dumped) - except Exception: - pass + def _get(key: str) -> typing.Any: + if isinstance(usage, dict): + return usage.get(key) + return getattr(usage, key, None) - return value - - def _usage_dict(value: typing.Any) -> dict[str, typing.Any]: - if value is None: - return {} - plain = _plain_value(value) - if isinstance(plain, dict): - return plain - - def _is_mock_attr(attr: typing.Any) -> bool: - return type(attr).__module__.startswith('unittest.mock') - - data: dict[str, typing.Any] = {} - for key in ( - 'prompt_tokens', - 'completion_tokens', - 'total_tokens', - 'prompt_tokens_details', - 'completion_tokens_details', - 'cache_creation_input_tokens', - 'cache_read_input_tokens', - 'input_token_details', - 'output_token_details', - ): - attr_value = getattr(value, key, None) - if attr_value is not None and not _is_mock_attr(attr_value): - data[key] = _plain_value(attr_value) - return data - - def _to_int(value: typing.Any) -> int: - try: - return int(value or 0) - except (TypeError, ValueError): - return 0 - - normalized = _usage_dict(usage) - - prompt_tokens = _to_int(normalized.get('prompt_tokens')) - completion_tokens = _to_int(normalized.get('completion_tokens')) - total_tokens = _to_int(normalized.get('total_tokens')) + prompt_tokens = _get('prompt_tokens') or 0 + completion_tokens = _get('completion_tokens') or 0 + total_tokens = _get('total_tokens') or 0 # Some providers omit total_tokens in streaming usage; derive it. if not total_tokens: total_tokens = prompt_tokens + completion_tokens - normalized['prompt_tokens'] = prompt_tokens - normalized['completion_tokens'] = completion_tokens - normalized['total_tokens'] = total_tokens - return normalized + return { + 'prompt_tokens': int(prompt_tokens), + 'completion_tokens': int(completion_tokens), + 'total_tokens': int(total_tokens), + } - def _extract_usage(self, response) -> dict | None: + def _extract_usage(self, response) -> dict: """Extract usage info from a non-streaming LiteLLM response.""" - usage = getattr(response, 'usage', None) - if usage is None: - return None - return self._normalize_usage(usage) + return self._normalize_usage(getattr(response, 'usage', None)) @staticmethod def _as_dict(value: typing.Any) -> dict: @@ -523,7 +474,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): if query is not None: if query.variables is None: query.variables = {} - query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info + query.variables['_stream_usage'] = usage_info if not hasattr(chunk, 'choices') or not chunk.choices: continue diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index 12bc6dbb..1ec12d82 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -115,15 +115,6 @@ class TestExtractUsage: assert result['prompt_tokens'] == 0 assert result['completion_tokens'] == 0 - def test_extract_usage_without_provider_usage(self): - """Missing provider usage is not treated as authoritative zero usage.""" - requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) - - response = Mock() - response.usage = None - - assert requester._extract_usage(response) is None - class TestNormalizeUsage: """Test _normalize_usage helper covering real-world usage shapes""" @@ -140,22 +131,6 @@ class TestNormalizeUsage: ) assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20} - def test_preserves_token_details(self): - """Provider token details such as cache counters are preserved.""" - result = litellmchat.LiteLLMRequester._normalize_usage( - { - 'prompt_tokens': 12, - 'completion_tokens': 8, - 'total_tokens': 20, - 'prompt_tokens_details': {'cached_tokens': 7}, - 'completion_tokens_details': {'reasoning_tokens': 3}, - } - ) - - assert result['prompt_tokens'] == 12 - assert result['prompt_tokens_details'] == {'cached_tokens': 7} - assert result['completion_tokens_details'] == {'reasoning_tokens': 3} - def test_missing_total_is_derived(self): """When total_tokens is absent/zero it is derived from prompt + completion""" usage = Mock()