diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index cb9a4183..b673c758 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -67,8 +67,8 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: - input_tokens = usage_info.get('input_tokens', 0) - output_tokens = usage_info.get('output_tokens', 0) + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) return msg else: return result @@ -128,7 +128,6 @@ class RuntimeProvider: start_time = time.time() status = 'success' error_message = None - # Note: Stream doesn't easily provide token counts, set to 0 input_tokens = 0 output_tokens = 0 @@ -143,6 +142,15 @@ class RuntimeProvider: remove_think=remove_think, ): yield chunk + # Extract usage from stream if available (stored by LiteLLM requester) + if query: + if query.variables is None: + query.variables = {} + if '_stream_usage' in query.variables: + usage_info = query.variables['_stream_usage'] + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) + del query.variables['_stream_usage'] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 85f89f97..ae776e4d 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -88,17 +88,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester): def _extract_usage(self, response) -> dict: """Extract usage info from LiteLLM response.""" usage = response.usage - usage_info = { + return { 'prompt_tokens': usage.prompt_tokens or 0, 'completion_tokens': usage.completion_tokens or 0, 'total_tokens': usage.total_tokens or 0, } - # TODO: LangBot internal inconsistency - LLM monitoring uses input_tokens/output_tokens, - # while embedding monitoring uses prompt_tokens. Should unify in requester.py to use - # prompt_tokens (OpenAI native) consistently. After that, remove these compatibility aliases. - usage_info['input_tokens'] = usage_info['prompt_tokens'] - usage_info['output_tokens'] = usage_info['completion_tokens'] - return usage_info def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict: """Apply common requester config to args dict.""" @@ -136,6 +130,37 @@ class LiteLLMRequester(requester.ProviderAPIRequester): raise errors.RequesterError(f'API 错误: {str(e)}') raise errors.RequesterError(f'未知错误: {str(e)}') + async def _build_completion_args( + self, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + stream: bool = False, + ) -> dict: + """Build common completion arguments for invoke_llm and invoke_llm_stream.""" + req_messages = self._convert_messages(messages) + model_name = self._build_litellm_model_name(model.model_entity.name) + api_key = model.provider.token_mgr.get_token() + + args = { + 'model': model_name, + 'messages': req_messages, + 'api_key': api_key, + } + if stream: + args['stream'] = True + args['stream_options'] = {'include_usage': True} + self._build_common_args(args) + args.update(extra_args) + + if funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) + if tools: + args['tools'] = tools + + return args + async def invoke_llm( self, query: pipeline_query.Query, @@ -146,25 +171,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): remove_think: bool = False, ) -> tuple[provider_message.Message, dict]: """Invoke LLM and return message with usage info.""" - # DO NOT modify input messages - copy them - req_messages = self._convert_messages(messages) - - model_name = self._build_litellm_model_name(model.model_entity.name) - - api_key = model.provider.token_mgr.get_token() - - args = { - 'model': model_name, - 'messages': req_messages, - 'api_key': api_key, - } - self._build_common_args(args) - args.update(extra_args) - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) - if tools: - args['tools'] = tools + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False) try: response = await acompletion(**args) @@ -198,24 +205,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): remove_think: bool = False, ) -> provider_message.MessageChunk: """Invoke LLM streaming and yield chunks.""" - req_messages = self._convert_messages(messages) - - model_name = self._build_litellm_model_name(model.model_entity.name) - api_key = model.provider.token_mgr.get_token() - - args = { - 'model': model_name, - 'messages': req_messages, - 'api_key': api_key, - 'stream': True, - } - self._build_common_args(args) - args.update(extra_args) - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) - if tools: - args['tools'] = tools + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True) chunk_idx = 0 role = 'assistant' @@ -223,6 +213,19 @@ class LiteLLMRequester(requester.ProviderAPIRequester): try: response = await acompletion(**args) async for chunk in response: + # Check for usage chunk (final chunk with stream_options include_usage) + if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices): + usage_info = { + 'prompt_tokens': chunk.usage.prompt_tokens or 0, + 'completion_tokens': chunk.usage.completion_tokens or 0, + 'total_tokens': chunk.usage.total_tokens or 0, + } + if query: + if query.variables is None: + query.variables = {} + query.variables['_stream_usage'] = usage_info + continue + if not hasattr(chunk, 'choices') or not chunk.choices: continue diff --git a/src/langbot/pkg/provider/tools/toolmgr.py b/src/langbot/pkg/provider/tools/toolmgr.py index 5c510fcd..3154387a 100644 --- a/src/langbot/pkg/provider/tools/toolmgr.py +++ b/src/langbot/pkg/provider/tools/toolmgr.py @@ -83,19 +83,6 @@ class ToolManager: return tools - async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list: - tools = [] - - for function in use_funcs: - function_schema = { - 'name': function.name, - 'description': function.description, - 'input_schema': function.parameters, - } - tools.append(function_schema) - - return tools - async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any: if await self.native_tool_loader.has_tool(name): return await self.native_tool_loader.invoke_tool(name, parameters, query) diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index fad81683..d80ea50e 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -93,8 +93,6 @@ class TestExtractUsage: assert result['prompt_tokens'] == 100 assert result['completion_tokens'] == 50 assert result['total_tokens'] == 150 - assert result['input_tokens'] == 100 # Compatibility alias - assert result['output_tokens'] == 50 # Compatibility alias def test_extract_usage_with_zero_values(self): """Test extraction when values are 0""" diff --git a/tests/unit_tests/provider/test_tool_manager.py b/tests/unit_tests/provider/test_tool_manager.py index 8e8439f5..fbfcb13f 100644 --- a/tests/unit_tests/provider/test_tool_manager.py +++ b/tests/unit_tests/provider/test_tool_manager.py @@ -1,7 +1,7 @@ """Unit tests for ToolManager. Tests cover: -- Tool schema generation for OpenAI and Anthropic +- Tool schema generation for OpenAI/LiteLLM - Tool execution dispatch """ @@ -109,28 +109,6 @@ class TestToolManagerSchemaGeneration: assert tool2['type'] == 'function' assert tool2['function']['name'] == 'calculate' - @pytest.mark.asyncio - async def test_generate_tools_for_anthropic(self, mock_app, sample_tools): - """Test that generate_tools_for_anthropic produces correct schema.""" - toolmgr = get_toolmgr_module() - - manager = toolmgr.ToolManager(mock_app) - result = await manager.generate_tools_for_anthropic(sample_tools) - - assert len(result) == 2 - - # Verify first tool schema (Anthropic format) - tool1 = result[0] - assert tool1['name'] == 'get_weather' - assert tool1['description'] == 'Get current weather for a location' - assert 'input_schema' in tool1 - assert tool1['input_schema']['type'] == 'object' - - # Verify second tool schema - tool2 = result[1] - assert tool2['name'] == 'calculate' - assert 'input_schema' in tool2 - @pytest.mark.asyncio async def test_generate_tools_empty_list(self, mock_app): """Test that generating tools from empty list returns empty list.""" @@ -141,9 +119,6 @@ class TestToolManagerSchemaGeneration: openai_result = await manager.generate_tools_for_openai([]) assert openai_result == [] - anthropic_result = await manager.generate_tools_for_anthropic([]) - assert anthropic_result == [] - @pytest.mark.asyncio async def test_openai_schema_fields_complete(self, mock_app, sample_tools): """Test that OpenAI schema includes all required fields.""" @@ -161,20 +136,6 @@ class TestToolManagerSchemaGeneration: assert 'description' in func assert 'parameters' in func - @pytest.mark.asyncio - async def test_anthropic_schema_fields_complete(self, mock_app, sample_tools): - """Test that Anthropic schema includes all required fields.""" - toolmgr = get_toolmgr_module() - - manager = toolmgr.ToolManager(mock_app) - result = await manager.generate_tools_for_anthropic(sample_tools) - - for tool_schema in result: - assert 'name' in tool_schema - assert 'description' in tool_schema - assert 'input_schema' in tool_schema - - class TestToolManagerExecuteFuncCall: """Tests for execute_func_call method."""