diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py
index 7da30b40f..7fd86e04c 100644
--- a/src/langbot/pkg/agent/runner/context_builder.py
+++ b/src/langbot/pkg/agent/runner/context_builder.py
@@ -184,7 +184,7 @@ class AgentRunContextBuilder:
     def _is_llm_model_resource(model_resource: ModelResource) -> bool:
         operations = model_resource.get('operations')
         if isinstance(operations, list) and operations:
-            return bool({'invoke', 'stream'} & {str(operation) for operation in operations})
+            return bool({'invoke', 'stream', 'count_tokens'} & {str(operation) for operation in operations})
         return model_resource.get('model_type') != 'rerank'

     async def _build_model_context_window_tokens(self, resources: AgentResources) -> int | None:
diff --git a/src/langbot/pkg/agent/runner/resource_builder.py b/src/langbot/pkg/agent/runner/resource_builder.py
index a2e8c3667..32238a1c6 100644
--- a/src/langbot/pkg/agent/runner/resource_builder.py
+++ b/src/langbot/pkg/agent/runner/resource_builder.py
@@ -101,9 +101,9 @@ class AgentResourceBuilder:
         seen_model_ids: set[str] = set()

         model_perms = set(manifest_perms.models)
-        include_llm = bool({'invoke', 'stream'} & model_perms)
+        include_llm = bool({'invoke', 'stream', 'count_tokens'} & model_perms)
         include_rerank = 'rerank' in model_perms
-        llm_operations = [operation for operation in ('invoke', 'stream') if operation in model_perms]
+        llm_operations = [operation for operation in ('invoke', 'stream', 'count_tokens') if operation in model_perms]

         if not include_llm and not include_rerank:
             return models
diff --git a/src/langbot/pkg/agent/runner/session_registry.py b/src/langbot/pkg/agent/runner/session_registry.py
index 2d2a316bb..8179ef90f 100644
--- a/src/langbot/pkg/agent/runner/session_registry.py
+++ b/src/langbot/pkg/agent/runner/session_registry.py
@@ -13,7 +13,7 @@ from .context_builder import AgentResources

 MAX_STEERING_QUEUE_ITEMS = 100
 DEFAULT_RESOURCE_OPERATIONS: dict[str, set[str]] = {
-    'model': {'invoke', 'stream', 'rerank'},
+    'model': {'invoke', 'stream', 'rerank', 'count_tokens'},
     'tool': {'detail', 'call'},
     'knowledge_base': {'list', 'retrieve'},
     'skill': {'activate'},
diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py
index befe2f207..43fa27aef 100644
--- a/src/langbot/pkg/plugin/handler.py
+++ b/src/langbot/pkg/plugin/handler.py
@@ -556,6 +556,60 @@ class RuntimeConnectionHandler(handler.Handler):
                 },
             )

+        @self.action(PluginToRuntimeAction.COUNT_TOKENS)
+        async def count_tokens(data: dict[str, Any]) -> handler.ActionResponse:
+            """Count model input tokens.
+
+            For AgentRunner calls: requires run_id and validates model_uuid against session.resources.models.
+            For regular plugin calls: no run_id, unrestricted access (backward compatibility).
+            """
+            llm_model_uuid = data['llm_model_uuid']
+            messages = data['messages']
+            # Treat an explicit null the same as an omitted key.
+            funcs = data.get('funcs') or []
+            extra_args = data.get('extra_args') or {}
+            run_id = data.get('run_id')
+            caller_plugin_identity = data.get('caller_plugin_identity')
+
+            if run_id:
+                _session, error = await _validate_run_authorization(
+                    run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity, operation='count_tokens'
+                )
+                if error:
+                    return error
+
+            llm_model = await self.ap.model_mgr.get_model_by_uuid(llm_model_uuid)
+            if llm_model is None:
+                return handler.ActionResponse.error(
+                    message=f'LLM model with llm_model_uuid {llm_model_uuid} not found',
+                )
+
+            messages_obj = [provider_message.Message.model_validate(message) for message in messages]
+
+            async def _placeholder_func(**kwargs):
+                pass
+
+            funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]
+            count_tokens_method = getattr(llm_model.provider.requester, 'count_tokens', None)
+            if not callable(count_tokens_method):
+                return handler.ActionResponse.error(message='LLM provider does not support token counting')
+
+            try:
+                tokens = await count_tokens_method(
+                    model=llm_model,
+                    messages=messages_obj,
+                    funcs=funcs_obj,
+                    extra_args=extra_args,
+                )
+            except NotImplementedError:
+                # The base requester defines count_tokens and raises this, so
+                # the callable() check above cannot detect lack of support.
+                return handler.ActionResponse.error(message='LLM provider does not support token counting')
+            except Exception as exc:
+                return handler.ActionResponse.error(message=f'Token counting failed: {exc}')
+
+            return handler.ActionResponse.success(data={'tokens': tokens})
+
         @self.action(PluginToRuntimeAction.INVOKE_LLM)
         async def invoke_llm(data: dict[str, Any]) -> handler.ActionResponse:
             """Invoke llm
diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py
index 377f7d4a8..80f1bc2e9 100644
--- a/src/langbot/pkg/provider/modelmgr/requester.py
+++ b/src/langbot/pkg/provider/modelmgr/requester.py
@@ -411,6 +411,23 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
         """
         pass

+    async def count_tokens(
+        self,
+        model: RuntimeLLMModel,
+        messages: typing.List[provider_message.Message],
+        funcs: typing.Optional[typing.List[resource_tool.LLMTool]] = None,
+        extra_args: typing.Optional[dict[str, typing.Any]] = None,
+    ) -> int:
+        """Count model input tokens before invoking the model.
+
+        Requesters should use the same provider/model conversion path as
+        ``invoke_llm`` so the preflight count matches the actual request shape.
+
+        The base implementation always raises ``NotImplementedError``;
+        ``None`` for ``funcs`` or ``extra_args`` means "not provided".
+        """
+        raise NotImplementedError('This requester does not support token counting')
+
     async def invoke_llm_stream(
         self,
         query: pipeline_query.Query,
diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
index c1b5ae0b6..d94b5bb76 100644
--- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
+++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
@@ -521,6 +521,44 @@ class LiteLLMRequester(requester.ProviderAPIRequester):

         return args

+    async def count_tokens(
+        self,
+        model: requester.RuntimeLLMModel,
+        messages: typing.List[provider_message.Message],
+        funcs: typing.Optional[typing.List[resource_tool.LLMTool]] = None,
+        extra_args: typing.Optional[dict[str, typing.Any]] = None,
+    ) -> int:
+        """Count input tokens with LiteLLM's model-aware tokenizer.
+
+        Builds completion arguments via ``_build_completion_args`` (stream=False)
+        and passes only the model, messages, and tool fields to the counter.
+
+        Raises:
+            errors.RequesterError: If the counter returns a non-int or negative value.
+        """
+        # A fresh dict per call avoids sharing one mutable default across calls.
+        request_extra_args = {} if extra_args is None else extra_args
+        args = await self._build_completion_args(model, messages, funcs, request_extra_args, stream=False)
+        count_args: dict[str, typing.Any] = {
+            'model': args['model'],
+            'messages': args['messages'],
+        }
+        if 'tools' in args:
+            count_args['tools'] = args['tools']
+        if 'tool_choice' in args:
+            count_args['tool_choice'] = args['tool_choice']
+
+        try:
+            tokens = litellm.token_counter(**count_args)
+        except Exception as e:
+            self._handle_litellm_error(e)
+            # Re-raise if the handler returned, so ``tokens`` is never read unbound.
+            raise
+
+        if isinstance(tokens, bool) or not isinstance(tokens, int) or tokens < 0:
+            raise errors.RequesterError(f'token counter returned invalid value: {tokens!r}')
+        return tokens
+
     async def invoke_llm(
         self,
         query: pipeline_query.Query,
diff --git a/tests/unit_tests/agent/conftest.py b/tests/unit_tests/agent/conftest.py
index a55dccf1c..f4cba930d 100644
--- a/tests/unit_tests/agent/conftest.py
+++ b/tests/unit_tests/agent/conftest.py
@@ -77,7 +77,7 @@ def make_session(
     }
     authorized_operations: dict[str, dict[str, set[str]]] = {
         'model': {
-            m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank'])
+            m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank', 'count_tokens'])
             for m in res.get('models', [])
             if m.get('model_id')
         },
diff --git a/tests/unit_tests/agent/test_resource_builder.py b/tests/unit_tests/agent/test_resource_builder.py
index dcb9d4ffb..ed642fa4c 100644
--- a/tests/unit_tests/agent/test_resource_builder.py
+++ b/tests/unit_tests/agent/test_resource_builder.py
@@ -14,7 +14,7 @@ from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder

 RUNNER_ID = 'plugin:test/runner/default'
 FULL_PERMISSIONS = {
-    'models': ['invoke', 'stream', 'rerank'],
+    'models': ['count_tokens', 'invoke', 'stream', 'rerank'],
     'tools': ['detail', 'call'],
     'knowledge_bases': ['list', 'retrieve'],
     'history': ['page', 'search'],
@@ -139,9 +139,24 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
     resources = await build_resources(app, query, descriptor)

     assert resources['models'] == [
-        {'model_id': 'primary', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
-        {'model_id': 'fallback', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
-        {'model_id': 'aux', 'model_type': 'llm', 'provider': 'aux-provider', 'operations': ['invoke', 'stream']},
+        {
+            'model_id': 'primary',
+            'model_type': 'llm',
+            'provider': 'test-provider',
+            'operations': ['invoke', 'stream', 'count_tokens'],
+        },
+        {
+            'model_id': 'fallback',
+            'model_type': 'llm',
+            'provider': 'test-provider',
+            'operations': ['invoke', 'stream', 'count_tokens'],
+        },
+        {
+            'model_id': 'aux',
+            'model_type': 'llm',
+            'provider': 'aux-provider',
+            'operations': ['invoke', 'stream', 'count_tokens'],
+        },
         {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
     ]

@@ -189,7 +204,12 @@ async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
     resources = await build_resources(app, query, descriptor)

     assert resources['models'] == [
-        {'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
+        {
+            'model_id': 'llm',
+            'model_type': 'llm',
+            'provider': 'test-provider',
+            'operations': ['invoke', 'stream', 'count_tokens'],
+        },
         {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
     ]

@@ -222,7 +242,12 @@ async def test_build_resources_accepts_dynamic_form_type_aliases(app):
     resources = await build_resources(app, query, descriptor)

     assert resources['models'] == [
-        {'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
+        {
+            'model_id': 'llm_alias',
+            'model_type': 'llm',
+            'provider': 'test-provider',
+            'operations': ['invoke', 'stream', 'count_tokens'],
+        },
     ]
     assert resources['knowledge_bases'] == [
         {'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
diff --git a/tests/unit_tests/plugin/test_handler_actions.py b/tests/unit_tests/plugin/test_handler_actions.py
index a371c239d..dae2fe009 100644
--- a/tests/unit_tests/plugin/test_handler_actions.py
+++ b/tests/unit_tests/plugin/test_handler_actions.py
@@ -615,6 +615,94 @@ class TestAgentRunProxyActions:
         assert response.data['usage'] == usage
         assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables

+    @pytest.mark.asyncio
+    async def test_count_tokens_validates_run_authorization_and_calls_provider(self, app):
+        """COUNT_TOKENS is run-scoped and forwards messages/tools to the model requester."""
+        from langbot.pkg.agent.runner.session_registry import get_session_registry
+
+        run_id = 'run_proxy_count_tokens'
+        query = self.query()
+        app.query_pool.cached_queries[906] = query
+
+        registry = get_session_registry()
+        await registry.unregister(run_id)
+        await registry.register(
+            run_id=run_id,
+            runner_id='plugin:test/runner/default',
+            query_id=906,
+            plugin_identity='test/runner',
+            resources=make_agent_resources(
+                models=[{'model_id': 'llm_count_001', 'operations': ['count_tokens']}],
+            ),
+        )
+
+        requester = SimpleNamespace(count_tokens=AsyncMock(return_value=37))
+        model = SimpleNamespace(
+            model_entity=SimpleNamespace(abilities=[], extra_args={'temperature': 0.2}),
+            provider=SimpleNamespace(requester=requester),
+        )
+        app.model_mgr.get_model_by_uuid.return_value = model
+        runtime_handler = make_handler(app)
+
+        try:
+            response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value]({
+                'run_id': run_id,
+                'caller_plugin_identity': 'test/runner',
+                'llm_model_uuid': 'llm_count_001',
+                'messages': [{'role': 'user', 'content': 'hello'}],
+                'funcs': [{
+                    'name': 'search',
+                    'human_desc': 'Search',
+                    'description': 'Search',
+                    'parameters': {'type': 'object'},
+                }],
+                'extra_args': {'temperature': 0.7},
+            })
+        finally:
+            await registry.unregister(run_id)
+
+        assert response.code == 0
+        assert response.data == {'tokens': 37}
+        requester.count_tokens.assert_awaited_once()
+        kwargs = requester.count_tokens.await_args.kwargs
+        assert kwargs['model'] is model
+        assert kwargs['messages'][0].content == 'hello'
+        assert [tool.name for tool in kwargs['funcs']] == ['search']
+        assert kwargs['extra_args'] == {'temperature': 0.7}
+
+    @pytest.mark.asyncio
+    async def test_count_tokens_rejects_model_without_operation(self, app):
+        """COUNT_TOKENS requires the explicit model operation in the run snapshot."""
+        from langbot.pkg.agent.runner.session_registry import get_session_registry
+
+        run_id = 'run_proxy_count_tokens_denied'
+        registry = get_session_registry()
+        await registry.unregister(run_id)
+        await registry.register(
+            run_id=run_id,
+            runner_id='plugin:test/runner/default',
+            query_id=None,
+            plugin_identity='test/runner',
+            resources=make_agent_resources(
+                models=[{'model_id': 'llm_count_002', 'operations': ['invoke']}],
+            ),
+        )
+
+        runtime_handler = make_handler(app)
+        try:
+            response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value]({
+                'run_id': run_id,
+                'caller_plugin_identity': 'test/runner',
+                'llm_model_uuid': 'llm_count_002',
+                'messages': [{'role': 'user', 'content': 'hello'}],
+            })
+        finally:
+            await registry.unregister(run_id)
+
+        assert response.code != 0
+        assert 'operation count_tokens' in response.message
+        app.model_mgr.get_model_by_uuid.assert_not_awaited()
+
     @pytest.mark.asyncio
     async def test_invoke_llm_stream_restores_query_and_options(self, app):
         """INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
diff --git a/tests/unit_tests/provider/test_provider_specific_fields.py b/tests/unit_tests/provider/test_provider_specific_fields.py
index 2520e608e..f9ae9a8fc 100644
--- a/tests/unit_tests/provider/test_provider_specific_fields.py
+++ b/tests/unit_tests/provider/test_provider_specific_fields.py
@@ -1,11 +1,17 @@
-"""Unit tests for provider_specific_fields round-trip in LiteLLMRequester.
+"""Unit tests for LiteLLMRequester message/tool conversion.

-This tests the fix for GitHub issue #1899: Gemini requires thought_signature
-to be preserved across tool call rounds for function calls to work correctly.
+This includes provider_specific_fields round-trip coverage for GitHub issue
+#1899 and token counting preflight behavior for AgentRunner context budgeting.
 """

-import langbot_plugin.api.entities.builtin.provider.message as provider_message
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, Mock, patch

+import pytest
+import langbot_plugin.api.entities.builtin.provider.message as provider_message
+import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
+
+from langbot.pkg.provider.modelmgr import requester as model_requester
 from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester


@@ -14,6 +20,84 @@ def _make_requester() -> LiteLLMRequester:
     return LiteLLMRequester.__new__(LiteLLMRequester)


+def _make_configured_requester() -> LiteLLMRequester:
+    req = LiteLLMRequester.__new__(LiteLLMRequester)
+    req.requester_cfg = {
+        'base_url': '',
+        'timeout': 120,
+        'custom_llm_provider': 'openai',
+        'drop_params': False,
+        'num_retries': 0,
+        'api_version': '',
+    }
+    req.ap = SimpleNamespace(
+        tool_mgr=SimpleNamespace(
+            generate_tools_for_openai=AsyncMock(
+                return_value=[
+                    {
+                        'type': 'function',
+                        'function': {
+                            'name': 'search',
+                            'description': 'Search',
+                            'parameters': {'type': 'object'},
+                        },
+                    }
+                ]
+            )
+        )
+    )
+    return req
+
+
+def _make_runtime_model() -> model_requester.RuntimeLLMModel:
+    provider = SimpleNamespace(token_mgr=SimpleNamespace(get_token=Mock(return_value='sk-test')))
+    return SimpleNamespace(
+        model_entity=SimpleNamespace(
+            name='gpt-4.1',
+            extra_args={'temperature': 0.2},
+        ),
+        provider=provider,
+    )
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_uses_litellm_counter_with_request_messages_and_tools():
+    """Token preflight uses the same LiteLLM request shape as chat completion."""
+    req = _make_configured_requester()
+    model = _make_runtime_model()
+    tool = resource_tool.LLMTool(
+        name='search',
+        human_desc='Search',
+        description='Search',
+        parameters={'type': 'object'},
+        func=lambda **kwargs: None,
+    )
+
+    with patch('langbot.pkg.provider.modelmgr.requesters.litellmchat.litellm.token_counter', return_value=42) as counter:
+        tokens = await req.count_tokens(
+            model=model,
+            messages=[
+                provider_message.Message(
+                    role='user',
+                    content=[
+                        provider_message.ContentElement(type='text', text='hello'),
+                        provider_message.ContentElement(type='file_url', file_url='https://example.test/a.pdf'),
+                    ],
+                )
+            ],
+            funcs=[tool],
+            extra_args={'presence_penalty': 0.1},
+        )
+
+    assert tokens == 42
+    counter.assert_called_once()
+    kwargs = counter.call_args.kwargs
+    assert kwargs['model'] == 'openai/gpt-4.1'
+    assert kwargs['messages'] == [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]}]
+    assert kwargs['tools'][0]['function']['name'] == 'search'
+    assert kwargs['tool_choice'] == 'auto'
+
+
 def test_convert_messages_preserves_tool_call_provider_specific_fields():
     """Tool calls should retain provider_specific_fields through _convert_messages."""
     req = _make_requester()