refactor(provider): simplify LiteLLM requester usage handling

- Remove unused Anthropic-specific tool schema generation
- Share completion argument construction between normal and streaming calls
- Use LiteLLM/OpenAI native usage fields for monitoring
- Collect stream token usage from LiteLLM stream_options
- Update LiteLLM requester tests for unified usage fields
This commit is contained in:
huanghuoguoguo
2026-04-25 09:22:37 +08:00
parent 31ad85517b
commit 7ea1ce2fd3
5 changed files with 59 additions and 102 deletions

View File

@@ -67,8 +67,8 @@ class RuntimeProvider:
if isinstance(result, tuple):
msg, usage_info = result
if usage_info:
input_tokens = usage_info.get('input_tokens', 0)
output_tokens = usage_info.get('output_tokens', 0)
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
return msg
else:
return result
@@ -128,7 +128,6 @@ class RuntimeProvider:
start_time = time.time()
status = 'success'
error_message = None
# Note: Stream doesn't easily provide token counts, set to 0
input_tokens = 0
output_tokens = 0
@@ -143,6 +142,15 @@ class RuntimeProvider:
remove_think=remove_think,
):
yield chunk
# Extract usage from stream if available (stored by LiteLLM requester)
if query:
if query.variables is None:
query.variables = {}
if '_stream_usage' in query.variables:
usage_info = query.variables['_stream_usage']
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
del query.variables['_stream_usage']
except Exception as e:
status = 'error'
error_message = str(e)

View File

@@ -88,17 +88,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
def _extract_usage(self, response) -> dict:
"""Extract usage info from LiteLLM response."""
usage = response.usage
usage_info = {
return {
'prompt_tokens': usage.prompt_tokens or 0,
'completion_tokens': usage.completion_tokens or 0,
'total_tokens': usage.total_tokens or 0,
}
# TODO: LangBot internal inconsistency - LLM monitoring uses input_tokens/output_tokens,
# while embedding monitoring uses prompt_tokens. Should unify in requester.py to use
# prompt_tokens (OpenAI native) consistently. After that, remove these compatibility aliases.
usage_info['input_tokens'] = usage_info['prompt_tokens']
usage_info['output_tokens'] = usage_info['completion_tokens']
return usage_info
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
"""Apply common requester config to args dict."""
@@ -136,6 +130,37 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
raise errors.RequesterError(f'API 错误: {str(e)}')
raise errors.RequesterError(f'未知错误: {str(e)}')
async def _build_completion_args(
    self,
    model: requester.RuntimeLLMModel,
    messages: typing.List[provider_message.Message],
    funcs: typing.List[resource_tool.LLMTool] = None,
    extra_args: typing.Optional[dict[str, typing.Any]] = None,
    stream: bool = False,
) -> dict:
    """Build the completion arguments shared by invoke_llm and invoke_llm_stream.

    Args:
        model: Runtime model; supplies the model entity name and the
            provider token used as the API key.
        messages: Conversation messages; they are passed through
            ``_convert_messages`` before being sent.
        funcs: Optional tools to expose to the model. Omitted from the
            result when empty or when no tool schema is generated.
        extra_args: Optional extra arguments merged in after the common
            requester config, so their keys overwrite earlier values.
            ``None`` (the default) and ``{}`` both mean "nothing extra".
        stream: When true, request a streamed response and ask for a
            usage report via ``stream_options``.

    Returns:
        The argument dict to pass to the completion call.
    """
    args = {
        'model': self._build_litellm_model_name(model.model_entity.name),
        'messages': self._convert_messages(messages),
        'api_key': model.provider.token_mgr.get_token(),
    }

    if stream:
        args['stream'] = True
        # Needed so token usage can be collected from the stream.
        args['stream_options'] = {'include_usage': True}

    self._build_common_args(args)

    # A None default avoids sharing one dict object across every call.
    if extra_args:
        args.update(extra_args)

    # Tools are set last, so generated schemas win over any 'tools' key
    # supplied through extra_args.
    if funcs:
        tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
        if tools:
            args['tools'] = tools

    return args
async def invoke_llm(
self,
query: pipeline_query.Query,
@@ -146,25 +171,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
remove_think: bool = False,
) -> tuple[provider_message.Message, dict]:
"""Invoke LLM and return message with usage info."""
# DO NOT modify input messages - copy them
req_messages = self._convert_messages(messages)
model_name = self._build_litellm_model_name(model.model_entity.name)
api_key = model.provider.token_mgr.get_token()
args = {
'model': model_name,
'messages': req_messages,
'api_key': api_key,
}
self._build_common_args(args)
args.update(extra_args)
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
if tools:
args['tools'] = tools
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
try:
response = await acompletion(**args)
@@ -198,24 +205,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
remove_think: bool = False,
) -> provider_message.MessageChunk:
"""Invoke LLM streaming and yield chunks."""
req_messages = self._convert_messages(messages)
model_name = self._build_litellm_model_name(model.model_entity.name)
api_key = model.provider.token_mgr.get_token()
args = {
'model': model_name,
'messages': req_messages,
'api_key': api_key,
'stream': True,
}
self._build_common_args(args)
args.update(extra_args)
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
if tools:
args['tools'] = tools
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
chunk_idx = 0
role = 'assistant'
@@ -223,6 +213,19 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
try:
response = await acompletion(**args)
async for chunk in response:
# Check for usage chunk (final chunk with stream_options include_usage)
if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices):
usage_info = {
'prompt_tokens': chunk.usage.prompt_tokens or 0,
'completion_tokens': chunk.usage.completion_tokens or 0,
'total_tokens': chunk.usage.total_tokens or 0,
}
if query:
if query.variables is None:
query.variables = {}
query.variables['_stream_usage'] = usage_info
continue
if not hasattr(chunk, 'choices') or not chunk.choices:
continue

View File

@@ -83,19 +83,6 @@ class ToolManager:
return tools
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
    """Return one Anthropic-style tool schema per function in ``use_funcs``.

    Each schema is a flat dict with the function's ``name`` and
    ``description``, and its ``parameters`` under the ``input_schema`` key.
    The output keeps the order of ``use_funcs``.
    """
    return [
        {
            'name': func.name,
            'description': func.description,
            'input_schema': func.parameters,
        }
        for func in use_funcs
    ]
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
if await self.native_tool_loader.has_tool(name):
return await self.native_tool_loader.invoke_tool(name, parameters, query)

View File

@@ -93,8 +93,6 @@ class TestExtractUsage:
assert result['prompt_tokens'] == 100
assert result['completion_tokens'] == 50
assert result['total_tokens'] == 150
assert result['input_tokens'] == 100 # Compatibility alias
assert result['output_tokens'] == 50 # Compatibility alias
def test_extract_usage_with_zero_values(self):
"""Test extraction when values are 0"""

View File

@@ -1,7 +1,7 @@
"""Unit tests for ToolManager.
Tests cover:
- Tool schema generation for OpenAI and Anthropic
- Tool schema generation for OpenAI/LiteLLM
- Tool execution dispatch
"""
@@ -109,28 +109,6 @@ class TestToolManagerSchemaGeneration:
assert tool2['type'] == 'function'
assert tool2['function']['name'] == 'calculate'
@pytest.mark.asyncio
async def test_generate_tools_for_anthropic(self, mock_app, sample_tools):
    """Check the Anthropic-format schemas generated for the sample tools."""
    manager = get_toolmgr_module().ToolManager(mock_app)

    schemas = await manager.generate_tools_for_anthropic(sample_tools)

    assert len(schemas) == 2
    weather_schema, calculate_schema = schemas[0], schemas[1]

    # First tool: name, description and an object-typed input schema.
    assert weather_schema['name'] == 'get_weather'
    assert weather_schema['description'] == 'Get current weather for a location'
    assert 'input_schema' in weather_schema
    weather_input = weather_schema['input_schema']
    assert weather_input['type'] == 'object'

    # Second tool: only the name and the presence of input_schema are checked.
    assert calculate_schema['name'] == 'calculate'
    assert 'input_schema' in calculate_schema
@pytest.mark.asyncio
async def test_generate_tools_empty_list(self, mock_app):
"""Test that generating tools from empty list returns empty list."""
@@ -141,9 +119,6 @@ class TestToolManagerSchemaGeneration:
openai_result = await manager.generate_tools_for_openai([])
assert openai_result == []
anthropic_result = await manager.generate_tools_for_anthropic([])
assert anthropic_result == []
@pytest.mark.asyncio
async def test_openai_schema_fields_complete(self, mock_app, sample_tools):
"""Test that OpenAI schema includes all required fields."""
@@ -161,20 +136,6 @@ class TestToolManagerSchemaGeneration:
assert 'description' in func
assert 'parameters' in func
@pytest.mark.asyncio
async def test_anthropic_schema_fields_complete(self, mock_app, sample_tools):
    """Every generated Anthropic schema has name, description and input_schema."""
    manager = get_toolmgr_module().ToolManager(mock_app)
    schemas = await manager.generate_tools_for_anthropic(sample_tools)

    required_keys = ('name', 'description', 'input_schema')
    for schema in schemas:
        for key in required_keys:
            assert key in schema
class TestToolManagerExecuteFuncCall:
"""Tests for execute_func_call method."""