mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-18 19:44:21 +00:00
refactor(provider): simplify LiteLLM requester usage handling
- Remove unused Anthropic-specific tool schema generation - Share completion argument construction between normal and streaming calls - Use LiteLLM/OpenAI native usage fields for monitoring - Collect stream token usage from LiteLLM stream_options - Update LiteLLM requester tests for unified usage fields
This commit is contained in:
@@ -67,8 +67,8 @@ class RuntimeProvider:
|
|||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
msg, usage_info = result
|
msg, usage_info = result
|
||||||
if usage_info:
|
if usage_info:
|
||||||
input_tokens = usage_info.get('input_tokens', 0)
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
output_tokens = usage_info.get('output_tokens', 0)
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
return msg
|
return msg
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
@@ -128,7 +128,6 @@ class RuntimeProvider:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
status = 'success'
|
status = 'success'
|
||||||
error_message = None
|
error_message = None
|
||||||
# Note: Stream doesn't easily provide token counts, set to 0
|
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
|
|
||||||
@@ -143,6 +142,15 @@ class RuntimeProvider:
|
|||||||
remove_think=remove_think,
|
remove_think=remove_think,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
# Extract usage from stream if available (stored by LiteLLM requester)
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
if '_stream_usage' in query.variables:
|
||||||
|
usage_info = query.variables['_stream_usage']
|
||||||
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
|
del query.variables['_stream_usage']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status = 'error'
|
status = 'error'
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
|
|||||||
@@ -88,17 +88,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
|||||||
def _extract_usage(self, response) -> dict:
|
def _extract_usage(self, response) -> dict:
|
||||||
"""Extract usage info from LiteLLM response."""
|
"""Extract usage info from LiteLLM response."""
|
||||||
usage = response.usage
|
usage = response.usage
|
||||||
usage_info = {
|
return {
|
||||||
'prompt_tokens': usage.prompt_tokens or 0,
|
'prompt_tokens': usage.prompt_tokens or 0,
|
||||||
'completion_tokens': usage.completion_tokens or 0,
|
'completion_tokens': usage.completion_tokens or 0,
|
||||||
'total_tokens': usage.total_tokens or 0,
|
'total_tokens': usage.total_tokens or 0,
|
||||||
}
|
}
|
||||||
# TODO: LangBot internal inconsistency - LLM monitoring uses input_tokens/output_tokens,
|
|
||||||
# while embedding monitoring uses prompt_tokens. Should unify in requester.py to use
|
|
||||||
# prompt_tokens (OpenAI native) consistently. After that, remove these compatibility aliases.
|
|
||||||
usage_info['input_tokens'] = usage_info['prompt_tokens']
|
|
||||||
usage_info['output_tokens'] = usage_info['completion_tokens']
|
|
||||||
return usage_info
|
|
||||||
|
|
||||||
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
|
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
|
||||||
"""Apply common requester config to args dict."""
|
"""Apply common requester config to args dict."""
|
||||||
@@ -136,6 +130,37 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
|||||||
raise errors.RequesterError(f'API 错误: {str(e)}')
|
raise errors.RequesterError(f'API 错误: {str(e)}')
|
||||||
raise errors.RequesterError(f'未知错误: {str(e)}')
|
raise errors.RequesterError(f'未知错误: {str(e)}')
|
||||||
|
|
||||||
|
async def _build_completion_args(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
stream: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Build common completion arguments for invoke_llm and invoke_llm_stream."""
|
||||||
|
req_messages = self._convert_messages(messages)
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'messages': req_messages,
|
||||||
|
'api_key': api_key,
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
args['stream'] = True
|
||||||
|
args['stream_options'] = {'include_usage': True}
|
||||||
|
self._build_common_args(args)
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
if funcs:
|
||||||
|
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||||
|
if tools:
|
||||||
|
args['tools'] = tools
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: pipeline_query.Query,
|
query: pipeline_query.Query,
|
||||||
@@ -146,25 +171,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
|||||||
remove_think: bool = False,
|
remove_think: bool = False,
|
||||||
) -> tuple[provider_message.Message, dict]:
|
) -> tuple[provider_message.Message, dict]:
|
||||||
"""Invoke LLM and return message with usage info."""
|
"""Invoke LLM and return message with usage info."""
|
||||||
# DO NOT modify input messages - copy them
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
|
||||||
req_messages = self._convert_messages(messages)
|
|
||||||
|
|
||||||
model_name = self._build_litellm_model_name(model.model_entity.name)
|
|
||||||
|
|
||||||
api_key = model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'model': model_name,
|
|
||||||
'messages': req_messages,
|
|
||||||
'api_key': api_key,
|
|
||||||
}
|
|
||||||
self._build_common_args(args)
|
|
||||||
args.update(extra_args)
|
|
||||||
|
|
||||||
if funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(**args)
|
response = await acompletion(**args)
|
||||||
@@ -198,24 +205,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
|||||||
remove_think: bool = False,
|
remove_think: bool = False,
|
||||||
) -> provider_message.MessageChunk:
|
) -> provider_message.MessageChunk:
|
||||||
"""Invoke LLM streaming and yield chunks."""
|
"""Invoke LLM streaming and yield chunks."""
|
||||||
req_messages = self._convert_messages(messages)
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
|
||||||
|
|
||||||
model_name = self._build_litellm_model_name(model.model_entity.name)
|
|
||||||
api_key = model.provider.token_mgr.get_token()
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'model': model_name,
|
|
||||||
'messages': req_messages,
|
|
||||||
'api_key': api_key,
|
|
||||||
'stream': True,
|
|
||||||
}
|
|
||||||
self._build_common_args(args)
|
|
||||||
args.update(extra_args)
|
|
||||||
|
|
||||||
if funcs:
|
|
||||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
|
||||||
if tools:
|
|
||||||
args['tools'] = tools
|
|
||||||
|
|
||||||
chunk_idx = 0
|
chunk_idx = 0
|
||||||
role = 'assistant'
|
role = 'assistant'
|
||||||
@@ -223,6 +213,19 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
|||||||
try:
|
try:
|
||||||
response = await acompletion(**args)
|
response = await acompletion(**args)
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
# Check for usage chunk (final chunk with stream_options include_usage)
|
||||||
|
if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices):
|
||||||
|
usage_info = {
|
||||||
|
'prompt_tokens': chunk.usage.prompt_tokens or 0,
|
||||||
|
'completion_tokens': chunk.usage.completion_tokens or 0,
|
||||||
|
'total_tokens': chunk.usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
query.variables['_stream_usage'] = usage_info
|
||||||
|
continue
|
||||||
|
|
||||||
if not hasattr(chunk, 'choices') or not chunk.choices:
|
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -57,41 +57,6 @@ class ToolManager:
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
|
|
||||||
"""为anthropic生成函数列表
|
|
||||||
|
|
||||||
e.g.
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "get_stock_price",
|
|
||||||
"description": "Get the current stock price for a given ticker symbol.",
|
|
||||||
"input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"ticker": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["ticker"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
for function in use_funcs:
|
|
||||||
function_schema = {
|
|
||||||
'name': function.name,
|
|
||||||
'description': function.description,
|
|
||||||
'input_schema': function.parameters,
|
|
||||||
}
|
|
||||||
tools.append(function_schema)
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
||||||
"""执行函数调用"""
|
"""执行函数调用"""
|
||||||
|
|
||||||
|
|||||||
@@ -93,8 +93,6 @@ class TestExtractUsage:
|
|||||||
assert result['prompt_tokens'] == 100
|
assert result['prompt_tokens'] == 100
|
||||||
assert result['completion_tokens'] == 50
|
assert result['completion_tokens'] == 50
|
||||||
assert result['total_tokens'] == 150
|
assert result['total_tokens'] == 150
|
||||||
assert result['input_tokens'] == 100 # Compatibility alias
|
|
||||||
assert result['output_tokens'] == 50 # Compatibility alias
|
|
||||||
|
|
||||||
def test_extract_usage_with_zero_values(self):
|
def test_extract_usage_with_zero_values(self):
|
||||||
"""Test extraction when values are 0"""
|
"""Test extraction when values are 0"""
|
||||||
|
|||||||
Reference in New Issue
Block a user