mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-26 23:44:19 +00:00
feat(agent-runner): support scoped token counting
This commit is contained in:
@@ -184,7 +184,7 @@ class AgentRunContextBuilder:
|
||||
def _is_llm_model_resource(model_resource: ModelResource) -> bool:
    """Return whether ``model_resource`` should be treated as an LLM resource.

    When the resource carries a non-empty ``operations`` list, it is an LLM
    resource if at least one operation is ``invoke``, ``stream`` or
    ``count_tokens``. Without a usable ``operations`` list the decision falls
    back to ``model_type``: anything other than ``'rerank'`` counts as an LLM.
    """
    operations = model_resource.get('operations')
    if isinstance(operations, list) and operations:
        # Entries are stringified before comparison, so non-str operation
        # values are matched by their str() form.
        llm_operations = {'invoke', 'stream', 'count_tokens'}
        return bool(llm_operations & {str(operation) for operation in operations})
    return model_resource.get('model_type') != 'rerank'
|
||||
|
||||
async def _build_model_context_window_tokens(self, resources: AgentResources) -> int | None:
|
||||
|
||||
@@ -101,9 +101,9 @@ class AgentResourceBuilder:
|
||||
seen_model_ids: set[str] = set()
|
||||
|
||||
model_perms = set(manifest_perms.models)
|
||||
include_llm = bool({'invoke', 'stream'} & model_perms)
|
||||
include_llm = bool({'invoke', 'stream', 'count_tokens'} & model_perms)
|
||||
include_rerank = 'rerank' in model_perms
|
||||
llm_operations = [operation for operation in ('invoke', 'stream') if operation in model_perms]
|
||||
llm_operations = [operation for operation in ('invoke', 'stream', 'count_tokens') if operation in model_perms]
|
||||
if not include_llm and not include_rerank:
|
||||
return models
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from .context_builder import AgentResources
|
||||
MAX_STEERING_QUEUE_ITEMS = 100
|
||||
|
||||
DEFAULT_RESOURCE_OPERATIONS: dict[str, set[str]] = {
|
||||
'model': {'invoke', 'stream', 'rerank'},
|
||||
'model': {'invoke', 'stream', 'rerank', 'count_tokens'},
|
||||
'tool': {'detail', 'call'},
|
||||
'knowledge_base': {'list', 'retrieve'},
|
||||
'skill': {'activate'},
|
||||
|
||||
@@ -556,6 +556,55 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
},
|
||||
)
|
||||
|
||||
@self.action(PluginToRuntimeAction.COUNT_TOKENS)
async def count_tokens(data: dict[str, Any]) -> handler.ActionResponse:
    """Count model input tokens.

    For AgentRunner calls: requires run_id and validates model_uuid against session.resources.models.
    For regular plugin calls: no run_id, unrestricted access (backward compatibility).

    Keys read from ``data``:
        llm_model_uuid: required; UUID used to look up the model.
        messages: required; list of message payloads validated into ``Message`` objects.
        funcs: optional; list of tool payloads validated into ``LLMTool`` objects.
        extra_args: optional; forwarded unchanged to the requester.
        run_id / caller_plugin_identity: optional; when ``run_id`` is truthy the
            call is authorized for the ``count_tokens`` operation first.

    Returns:
        A success response with ``{'tokens': <count>}``, or an error response when
        authorization fails, the model is unknown, the requester has no callable
        ``count_tokens``, or counting raises.
    """
    llm_model_uuid = data['llm_model_uuid']
    messages = data['messages']
    # ``or`` (rather than a .get default) also covers keys that are present
    # but explicitly null, which would otherwise break the iteration below.
    funcs = data.get('funcs') or []
    extra_args = data.get('extra_args') or {}
    run_id = data.get('run_id')
    caller_plugin_identity = data.get('caller_plugin_identity')

    if run_id:
        _session, error = await _validate_run_authorization(
            run_id, 'model', llm_model_uuid, self.ap, caller_plugin_identity, operation='count_tokens'
        )
        if error:
            return error

    llm_model = await self.ap.model_mgr.get_model_by_uuid(llm_model_uuid)
    if llm_model is None:
        return handler.ActionResponse.error(
            message=f'LLM model with llm_model_uuid {llm_model_uuid} not found',
        )

    messages_obj = [provider_message.Message.model_validate(message) for message in messages]

    async def _placeholder_func(**kwargs):
        # Tools are only described for counting, never executed; this stub
        # fills the ``func`` field merged into every tool payload.
        pass

    funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]

    # Looked up dynamically so requesters without the method yield a clean error.
    count_tokens_method = getattr(llm_model.provider.requester, 'count_tokens', None)
    if not callable(count_tokens_method):
        return handler.ActionResponse.error(message='LLM provider does not support token counting')

    try:
        tokens = await count_tokens_method(
            model=llm_model,
            messages=messages_obj,
            funcs=funcs_obj,
            extra_args=extra_args,
        )
    except Exception as exc:
        # Action boundary: report the failure to the caller instead of raising.
        return handler.ActionResponse.error(message=f'Token counting failed: {exc}')

    return handler.ActionResponse.success(data={'tokens': tokens})
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_LLM)
|
||||
async def invoke_llm(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Invoke llm
|
||||
|
||||
@@ -411,6 +411,20 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def count_tokens(
    self,
    model: RuntimeLLMModel,
    messages: typing.List[provider_message.Message],
    funcs: typing.Optional[typing.List[resource_tool.LLMTool]] = None,
    extra_args: typing.Optional[dict[str, typing.Any]] = None,
) -> int:
    """Count model input tokens before invoking the model.

    Requesters should use the same provider/model conversion path as
    ``invoke_llm`` so the preflight count matches the actual request shape.

    Args:
        model: Runtime model the count is computed for.
        messages: Conversation messages to count.
        funcs: Optional tools that would accompany the request.
        extra_args: Optional extra request arguments. Defaults to ``None``
            rather than a shared mutable ``{}``.

    Returns:
        The number of input tokens.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError('This requester does not support token counting')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
|
||||
@@ -521,6 +521,33 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
|
||||
return args
|
||||
|
||||
async def count_tokens(
    self,
    model: requester.RuntimeLLMModel,
    messages: typing.List[provider_message.Message],
    funcs: typing.Optional[typing.List[resource_tool.LLMTool]] = None,
    extra_args: typing.Optional[dict[str, typing.Any]] = None,
) -> int:
    """Count input tokens with LiteLLM's model-aware tokenizer.

    The arguments are first run through ``_build_completion_args`` (with
    ``stream=False``); only ``model``, ``messages`` and, when present,
    ``tools`` / ``tool_choice`` are then handed to ``litellm.token_counter``.

    Args:
        model: Runtime model the count is computed for.
        messages: Conversation messages to count.
        funcs: Optional tools that would accompany the request.
        extra_args: Optional extra request arguments; ``None`` is treated as
            an empty dict.

    Returns:
        The non-negative token count reported by LiteLLM.

    Raises:
        errors.RequesterError: If the counter returns anything other than a
            non-negative ``int`` (``bool`` is rejected explicitly).
        Exception: Whatever ``_handle_litellm_error`` raises for a counter
            failure, or the original exception if that handler returns.
    """
    if extra_args is None:
        # Fresh dict per call instead of a shared mutable default.
        extra_args = {}

    args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)

    count_args: dict[str, typing.Any] = {
        'model': args['model'],
        'messages': args['messages'],
    }
    for optional_key in ('tools', 'tool_choice'):
        if optional_key in args:
            count_args[optional_key] = args[optional_key]

    try:
        tokens = litellm.token_counter(**count_args)
    except Exception as e:
        self._handle_litellm_error(e)
        # Presumably the handler always raises; if it ever returns, re-raise
        # the original error rather than hit an unbound ``tokens`` below.
        raise

    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(tokens, bool) or not isinstance(tokens, int) or tokens < 0:
        raise errors.RequesterError(f'token counter returned invalid value: {tokens!r}')
    return tokens
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
|
||||
@@ -77,7 +77,7 @@ def make_session(
|
||||
}
|
||||
authorized_operations: dict[str, dict[str, set[str]]] = {
|
||||
'model': {
|
||||
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank'])
|
||||
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank', 'count_tokens'])
|
||||
for m in res.get('models', [])
|
||||
if m.get('model_id')
|
||||
},
|
||||
|
||||
@@ -14,7 +14,7 @@ from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
|
||||
|
||||
RUNNER_ID = 'plugin:test/runner/default'
|
||||
FULL_PERMISSIONS = {
|
||||
'models': ['invoke', 'stream', 'rerank'],
|
||||
'models': ['count_tokens', 'invoke', 'stream', 'rerank'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
'history': ['page', 'search'],
|
||||
@@ -139,9 +139,24 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'primary', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'fallback', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{'model_id': 'aux', 'model_type': 'llm', 'provider': 'aux-provider', 'operations': ['invoke', 'stream']},
|
||||
{
|
||||
'model_id': 'primary',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream', 'count_tokens'],
|
||||
},
|
||||
{
|
||||
'model_id': 'fallback',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream', 'count_tokens'],
|
||||
},
|
||||
{
|
||||
'model_id': 'aux',
|
||||
'model_type': 'llm',
|
||||
'provider': 'aux-provider',
|
||||
'operations': ['invoke', 'stream', 'count_tokens'],
|
||||
},
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
|
||||
]
|
||||
|
||||
@@ -189,7 +204,12 @@ async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{
|
||||
'model_id': 'llm',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream', 'count_tokens'],
|
||||
},
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
|
||||
]
|
||||
|
||||
@@ -222,7 +242,12 @@ async def test_build_resources_accepts_dynamic_form_type_aliases(app):
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
{
|
||||
'model_id': 'llm_alias',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream', 'count_tokens'],
|
||||
},
|
||||
]
|
||||
assert resources['knowledge_bases'] == [
|
||||
{'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
|
||||
|
||||
@@ -615,6 +615,94 @@ class TestAgentRunProxyActions:
|
||||
assert response.data['usage'] == usage
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
async def test_count_tokens_validates_run_authorization_and_calls_provider(self, app):
    """A run authorized for ``count_tokens`` gets a count and the requester sees the payload.

    Registers a run whose only model operation is ``count_tokens``, calls the
    COUNT_TOKENS action, and checks both the response and the keyword
    arguments forwarded to the requester's ``count_tokens``.
    """
    from langbot.pkg.agent.runner.session_registry import get_session_registry

    run_id = 'run_proxy_count_tokens'
    app.query_pool.cached_queries[906] = self.query()

    session_registry = get_session_registry()
    # Drop any stale registration for this run id before registering.
    await session_registry.unregister(run_id)
    authorized_resources = make_agent_resources(
        models=[{'model_id': 'llm_count_001', 'operations': ['count_tokens']}],
    )
    await session_registry.register(
        run_id=run_id,
        runner_id='plugin:test/runner/default',
        query_id=906,
        plugin_identity='test/runner',
        resources=authorized_resources,
    )

    counting_requester = SimpleNamespace(count_tokens=AsyncMock(return_value=37))
    fake_model = SimpleNamespace(
        model_entity=SimpleNamespace(abilities=[], extra_args={'temperature': 0.2}),
        provider=SimpleNamespace(requester=counting_requester),
    )
    app.model_mgr.get_model_by_uuid.return_value = fake_model
    runtime_handler = make_handler(app)

    search_tool_payload = {
        'name': 'search',
        'human_desc': 'Search',
        'description': 'Search',
        'parameters': {'type': 'object'},
    }
    request_payload = {
        'run_id': run_id,
        'caller_plugin_identity': 'test/runner',
        'llm_model_uuid': 'llm_count_001',
        'messages': [{'role': 'user', 'content': 'hello'}],
        'funcs': [search_tool_payload],
        'extra_args': {'temperature': 0.7},
    }

    try:
        response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value](request_payload)
    finally:
        await session_registry.unregister(run_id)

    assert response.code == 0
    assert response.data == {'tokens': 37}

    counting_requester.count_tokens.assert_awaited_once()
    forwarded = counting_requester.count_tokens.await_args.kwargs
    assert forwarded['model'] is fake_model
    assert forwarded['messages'][0].content == 'hello'
    assert [tool.name for tool in forwarded['funcs']] == ['search']
    assert forwarded['extra_args'] == {'temperature': 0.7}
|
||||
|
||||
@pytest.mark.asyncio
async def test_count_tokens_rejects_model_without_operation(self, app):
    """A run whose model only allows ``invoke`` is refused by COUNT_TOKENS.

    The response must be an error mentioning ``operation count_tokens`` and
    the model manager must not be awaited.
    """
    from langbot.pkg.agent.runner.session_registry import get_session_registry

    run_id = 'run_proxy_count_tokens_denied'
    session_registry = get_session_registry()
    # Drop any stale registration for this run id before registering.
    await session_registry.unregister(run_id)
    invoke_only_resources = make_agent_resources(
        models=[{'model_id': 'llm_count_002', 'operations': ['invoke']}],
    )
    await session_registry.register(
        run_id=run_id,
        runner_id='plugin:test/runner/default',
        query_id=None,
        plugin_identity='test/runner',
        resources=invoke_only_resources,
    )

    runtime_handler = make_handler(app)
    request_payload = {
        'run_id': run_id,
        'caller_plugin_identity': 'test/runner',
        'llm_model_uuid': 'llm_count_002',
        'messages': [{'role': 'user', 'content': 'hello'}],
    }
    try:
        response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value](request_payload)
    finally:
        await session_registry.unregister(run_id)

    assert response.code != 0
    assert 'operation count_tokens' in response.message
    app.model_mgr.get_model_by_uuid.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_restores_query_and_options(self, app):
|
||||
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
"""Unit tests for provider_specific_fields round-trip in LiteLLMRequester.
|
||||
"""Unit tests for LiteLLMRequester message/tool conversion.
|
||||
|
||||
This tests the fix for GitHub issue #1899: Gemini requires thought_signature
|
||||
to be preserved across tool call rounds for function calls to work correctly.
|
||||
This includes provider_specific_fields round-trip coverage for GitHub issue
|
||||
#1899 and token counting preflight behavior for AgentRunner context budgeting.
|
||||
"""
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
|
||||
|
||||
|
||||
@@ -14,6 +20,84 @@ def _make_requester() -> LiteLLMRequester:
|
||||
return LiteLLMRequester.__new__(LiteLLMRequester)
|
||||
|
||||
|
||||
def _make_configured_requester() -> LiteLLMRequester:
    """Create a LiteLLMRequester via ``__new__`` with a stub config and tool manager.

    ``__init__`` is bypassed; ``requester_cfg`` and ``ap`` are assigned
    directly. ``ap.tool_mgr.generate_tools_for_openai`` is an ``AsyncMock``
    returning a single ``search`` function tool.
    """
    search_tool_schema = {
        'type': 'function',
        'function': {
            'name': 'search',
            'description': 'Search',
            'parameters': {'type': 'object'},
        },
    }
    tool_manager = SimpleNamespace(
        generate_tools_for_openai=AsyncMock(return_value=[search_tool_schema]),
    )

    configured = LiteLLMRequester.__new__(LiteLLMRequester)
    configured.requester_cfg = {
        'base_url': '',
        'timeout': 120,
        'custom_llm_provider': 'openai',
        'drop_params': False,
        'num_retries': 0,
        'api_version': '',
    }
    configured.ap = SimpleNamespace(tool_mgr=tool_manager)
    return configured
|
||||
|
||||
|
||||
def _make_runtime_model() -> model_requester.RuntimeLLMModel:
    """Return a ``SimpleNamespace`` stand-in for a runtime LLM model.

    The stand-in exposes ``model_entity`` (name ``gpt-4.1`` with a
    ``temperature`` extra arg) and a ``provider`` whose
    ``token_mgr.get_token`` is a ``Mock`` returning ``'sk-test'``.
    """
    token_manager = SimpleNamespace(get_token=Mock(return_value='sk-test'))
    entity = SimpleNamespace(name='gpt-4.1', extra_args={'temperature': 0.2})
    return SimpleNamespace(
        model_entity=entity,
        provider=SimpleNamespace(token_mgr=token_manager),
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_count_tokens_uses_litellm_counter_with_request_messages_and_tools():
    """``count_tokens`` hands LiteLLM's counter the converted model, messages and tools.

    ``litellm.token_counter`` is patched to return 42. The test then checks
    the keyword arguments it received: the model name, the messages (the
    expected list holds only the text element of the two sent), the
    ``search`` tool and ``tool_choice == 'auto'``.
    """
    requester_under_test = _make_configured_requester()
    runtime_model = _make_runtime_model()
    search_tool = resource_tool.LLMTool(
        name='search',
        human_desc='Search',
        description='Search',
        parameters={'type': 'object'},
        func=lambda **kwargs: None,
    )
    user_message = provider_message.Message(
        role='user',
        content=[
            provider_message.ContentElement(type='text', text='hello'),
            provider_message.ContentElement(type='file_url', file_url='https://example.test/a.pdf'),
        ],
    )

    with patch('langbot.pkg.provider.modelmgr.requesters.litellmchat.litellm.token_counter', return_value=42) as counter:
        tokens = await requester_under_test.count_tokens(
            model=runtime_model,
            messages=[user_message],
            funcs=[search_tool],
            extra_args={'presence_penalty': 0.1},
        )

    assert tokens == 42
    counter.assert_called_once()

    counter_kwargs = counter.call_args.kwargs
    assert counter_kwargs['model'] == 'openai/gpt-4.1'
    assert counter_kwargs['messages'] == [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]}]
    assert counter_kwargs['tools'][0]['function']['name'] == 'search'
    assert counter_kwargs['tool_choice'] == 'auto'
|
||||
|
||||
|
||||
def test_convert_messages_preserves_tool_call_provider_specific_fields():
|
||||
"""Tool calls should retain provider_specific_fields through _convert_messages."""
|
||||
req = _make_requester()
|
||||
|
||||
Reference in New Issue
Block a user