diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index 08fee3ab..5963d9e7 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -67,8 +67,8 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: - input_tokens = usage_info.get('input_tokens', 0) - output_tokens = usage_info.get('output_tokens', 0) + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) return msg else: return result @@ -128,7 +128,6 @@ class RuntimeProvider: start_time = time.time() status = 'success' error_message = None - # Note: Stream doesn't easily provide token counts, set to 0 input_tokens = 0 output_tokens = 0 @@ -143,6 +142,15 @@ class RuntimeProvider: remove_think=remove_think, ): yield chunk + # Extract usage from stream if available (stored by LiteLLM requester) + if query: + if query.variables is None: + query.variables = {} + if '_stream_usage' in query.variables: + usage_info = query.variables['_stream_usage'] + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) + del query.variables['_stream_usage'] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 85f89f97..ae776e4d 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -88,17 +88,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester): def _extract_usage(self, response) -> dict: """Extract usage info from LiteLLM response.""" usage = response.usage - usage_info = { + return { 'prompt_tokens': usage.prompt_tokens or 0, 'completion_tokens': usage.completion_tokens or 0, 'total_tokens': usage.total_tokens or 0, } - # TODO: LangBot internal inconsistency - LLM monitoring uses input_tokens/output_tokens, - # while embedding monitoring uses prompt_tokens. Should unify in requester.py to use - # prompt_tokens (OpenAI native) consistently. After that, remove these compatibility aliases. - usage_info['input_tokens'] = usage_info['prompt_tokens'] - usage_info['output_tokens'] = usage_info['completion_tokens'] - return usage_info def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict: """Apply common requester config to args dict.""" @@ -136,6 +130,37 @@ class LiteLLMRequester(requester.ProviderAPIRequester): raise errors.RequesterError(f'API 错误: {str(e)}') raise errors.RequesterError(f'未知错误: {str(e)}') + async def _build_completion_args( + self, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + stream: bool = False, + ) -> dict: + """Build common completion arguments for invoke_llm and invoke_llm_stream.""" + req_messages = self._convert_messages(messages) + model_name = self._build_litellm_model_name(model.model_entity.name) + api_key = model.provider.token_mgr.get_token() + + args = { + 'model': model_name, + 'messages': req_messages, + 'api_key': api_key, + } + if stream: + args['stream'] = True + args['stream_options'] = {'include_usage': True} + self._build_common_args(args) + args.update(extra_args) + + if funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) + if tools: + args['tools'] = tools + + return args + async def invoke_llm( self, query: pipeline_query.Query, @@ -146,25 +171,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): remove_think: bool = False, ) -> tuple[provider_message.Message, dict]: """Invoke LLM and return message with usage info.""" - # DO NOT modify input messages - copy them - req_messages = self._convert_messages(messages) - - model_name = self._build_litellm_model_name(model.model_entity.name) - - api_key = model.provider.token_mgr.get_token() - - args = { - 'model': model_name, - 'messages': req_messages, - 'api_key': api_key, - } - self._build_common_args(args) - args.update(extra_args) - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) - if tools: - args['tools'] = tools + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False) try: response = await acompletion(**args) @@ -198,24 +205,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): remove_think: bool = False, ) -> provider_message.MessageChunk: """Invoke LLM streaming and yield chunks.""" - req_messages = self._convert_messages(messages) - - model_name = self._build_litellm_model_name(model.model_entity.name) - api_key = model.provider.token_mgr.get_token() - - args = { - 'model': model_name, - 'messages': req_messages, - 'api_key': api_key, - 'stream': True, - } - self._build_common_args(args) - args.update(extra_args) - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) - if tools: - args['tools'] = tools + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True) chunk_idx = 0 role = 'assistant' @@ -223,6 +213,19 @@ class LiteLLMRequester(requester.ProviderAPIRequester): try: response = await acompletion(**args) async for chunk in response: + # Check for usage chunk (final chunk with stream_options include_usage) + if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices): + usage_info = { + 'prompt_tokens': chunk.usage.prompt_tokens or 0, + 'completion_tokens': chunk.usage.completion_tokens or 0, + 'total_tokens': chunk.usage.total_tokens or 0, + } + if query: + if query.variables is None: + query.variables = {} + query.variables['_stream_usage'] = usage_info + continue + if not hasattr(chunk, 'choices') or not chunk.choices: continue diff --git a/src/langbot/pkg/provider/tools/toolmgr.py b/src/langbot/pkg/provider/tools/toolmgr.py index f921c094..dc3d27e8 100644 --- a/src/langbot/pkg/provider/tools/toolmgr.py +++ b/src/langbot/pkg/provider/tools/toolmgr.py @@ -57,41 +57,6 @@ class ToolManager: return tools - async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list: - """为anthropic生成函数列表 - - e.g. - - [ - { - "name": "get_stock_price", - "description": "Get the current stock price for a given ticker symbol.", - "input_schema": { - "type": "object", - "properties": { - "ticker": { - "type": "string", - "description": "The stock ticker symbol, e.g. AAPL for Apple Inc." - } - }, - "required": ["ticker"] - } - } - ] - """ - - tools = [] - - for function in use_funcs: - function_schema = { - 'name': function.name, - 'description': function.description, - 'input_schema': function.parameters, - } - tools.append(function_schema) - - return tools - async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any: """执行函数调用""" diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index fad81683..d80ea50e 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -93,8 +93,6 @@ class TestExtractUsage: assert result['prompt_tokens'] == 100 assert result['completion_tokens'] == 50 assert result['total_tokens'] == 150 - assert result['input_tokens'] == 100 # Compatibility alias - assert result['output_tokens'] == 50 # Compatibility alias def test_extract_usage_with_zero_values(self): """Test extraction when values are 0"""