refactor(provider): simplify LiteLLM requester usage handling

- Remove unused Anthropic-specific tool schema generation
  - Share completion argument construction between normal and streaming calls
  - Use LiteLLM/OpenAI native usage fields for monitoring
  - Collect stream token usage from LiteLLM stream_options
  - Update LiteLLM requester tests for unified usage fields
This commit is contained in:
huanghuoguoguo
2026-04-25 09:22:37 +08:00
parent b33d05f99a
commit d170bdd343
4 changed files with 58 additions and 84 deletions

View File

@@ -67,8 +67,8 @@ class RuntimeProvider:
if isinstance(result, tuple):
msg, usage_info = result
if usage_info:
input_tokens = usage_info.get('input_tokens', 0)
output_tokens = usage_info.get('output_tokens', 0)
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
return msg
else:
return result
@@ -128,7 +128,6 @@ class RuntimeProvider:
start_time = time.time()
status = 'success'
error_message = None
# Note: Stream doesn't easily provide token counts, set to 0
input_tokens = 0
output_tokens = 0
@@ -143,6 +142,15 @@ class RuntimeProvider:
remove_think=remove_think,
):
yield chunk
# Extract usage from stream if available (stored by LiteLLM requester)
if query:
if query.variables is None:
query.variables = {}
if '_stream_usage' in query.variables:
usage_info = query.variables['_stream_usage']
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
del query.variables['_stream_usage']
except Exception as e:
status = 'error'
error_message = str(e)

View File

@@ -88,17 +88,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
def _extract_usage(self, response) -> dict:
"""Extract usage info from LiteLLM response."""
usage = response.usage
usage_info = {
return {
'prompt_tokens': usage.prompt_tokens or 0,
'completion_tokens': usage.completion_tokens or 0,
'total_tokens': usage.total_tokens or 0,
}
# TODO: LangBot internal inconsistency - LLM monitoring uses input_tokens/output_tokens,
# while embedding monitoring uses prompt_tokens. Should unify in requester.py to use
# prompt_tokens (OpenAI native) consistently. After that, remove these compatibility aliases.
usage_info['input_tokens'] = usage_info['prompt_tokens']
usage_info['output_tokens'] = usage_info['completion_tokens']
return usage_info
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
"""Apply common requester config to args dict."""
@@ -136,6 +130,37 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
raise errors.RequesterError(f'API 错误: {str(e)}')
raise errors.RequesterError(f'未知错误: {str(e)}')
async def _build_completion_args(
self,
model: requester.RuntimeLLMModel,
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
stream: bool = False,
) -> dict:
"""Build common completion arguments for invoke_llm and invoke_llm_stream."""
req_messages = self._convert_messages(messages)
model_name = self._build_litellm_model_name(model.model_entity.name)
api_key = model.provider.token_mgr.get_token()
args = {
'model': model_name,
'messages': req_messages,
'api_key': api_key,
}
if stream:
args['stream'] = True
args['stream_options'] = {'include_usage': True}
self._build_common_args(args)
args.update(extra_args)
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
if tools:
args['tools'] = tools
return args
async def invoke_llm(
self,
query: pipeline_query.Query,
@@ -146,25 +171,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
remove_think: bool = False,
) -> tuple[provider_message.Message, dict]:
"""Invoke LLM and return message with usage info."""
# DO NOT modify input messages - copy them
req_messages = self._convert_messages(messages)
model_name = self._build_litellm_model_name(model.model_entity.name)
api_key = model.provider.token_mgr.get_token()
args = {
'model': model_name,
'messages': req_messages,
'api_key': api_key,
}
self._build_common_args(args)
args.update(extra_args)
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
if tools:
args['tools'] = tools
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
try:
response = await acompletion(**args)
@@ -198,24 +205,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
remove_think: bool = False,
) -> provider_message.MessageChunk:
"""Invoke LLM streaming and yield chunks."""
req_messages = self._convert_messages(messages)
model_name = self._build_litellm_model_name(model.model_entity.name)
api_key = model.provider.token_mgr.get_token()
args = {
'model': model_name,
'messages': req_messages,
'api_key': api_key,
'stream': True,
}
self._build_common_args(args)
args.update(extra_args)
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
if tools:
args['tools'] = tools
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
chunk_idx = 0
role = 'assistant'
@@ -223,6 +213,19 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
try:
response = await acompletion(**args)
async for chunk in response:
# Check for usage chunk (final chunk with stream_options include_usage)
if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices):
usage_info = {
'prompt_tokens': chunk.usage.prompt_tokens or 0,
'completion_tokens': chunk.usage.completion_tokens or 0,
'total_tokens': chunk.usage.total_tokens or 0,
}
if query:
if query.variables is None:
query.variables = {}
query.variables['_stream_usage'] = usage_info
continue
if not hasattr(chunk, 'choices') or not chunk.choices:
continue

View File

@@ -57,41 +57,6 @@ class ToolManager:
return tools
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
"""为anthropic生成函数列表
e.g.
[
{
"name": "get_stock_price",
"description": "Get the current stock price for a given ticker symbol.",
"input_schema": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
}
},
"required": ["ticker"]
}
}
]
"""
tools = []
for function in use_funcs:
function_schema = {
'name': function.name,
'description': function.description,
'input_schema': function.parameters,
}
tools.append(function_schema)
return tools
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
"""执行函数调用"""

View File

@@ -93,8 +93,6 @@ class TestExtractUsage:
assert result['prompt_tokens'] == 100
assert result['completion_tokens'] == 50
assert result['total_tokens'] == 150
assert result['input_tokens'] == 100 # Compatibility alias
assert result['output_tokens'] == 50 # Compatibility alias
def test_extract_usage_with_zero_values(self):
"""Test extraction when values are 0"""