feat(agent-runner): support scoped token counting

This commit is contained in:
huanghuoguoguo
2026-06-27 01:31:08 +08:00
parent ae49753f74
commit d0f6fe2cec
10 changed files with 302 additions and 15 deletions
+1 -1
View File
@@ -77,7 +77,7 @@ def make_session(
}
authorized_operations: dict[str, dict[str, set[str]]] = {
'model': {
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank'])
m.get('model_id'): set(m.get('operations') or ['invoke', 'stream', 'rerank', 'count_tokens'])
for m in res.get('models', [])
if m.get('model_id')
},
@@ -14,7 +14,7 @@ from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
RUNNER_ID = 'plugin:test/runner/default'
FULL_PERMISSIONS = {
'models': ['invoke', 'stream', 'rerank'],
'models': ['count_tokens', 'invoke', 'stream', 'rerank'],
'tools': ['detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
'history': ['page', 'search'],
@@ -139,9 +139,24 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'primary', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
{'model_id': 'fallback', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
{'model_id': 'aux', 'model_type': 'llm', 'provider': 'aux-provider', 'operations': ['invoke', 'stream']},
{
'model_id': 'primary',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream', 'count_tokens'],
},
{
'model_id': 'fallback',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream', 'count_tokens'],
},
{
'model_id': 'aux',
'model_type': 'llm',
'provider': 'aux-provider',
'operations': ['invoke', 'stream', 'count_tokens'],
},
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
]
@@ -189,7 +204,12 @@ async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
{
'model_id': 'llm',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream', 'count_tokens'],
},
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider', 'operations': ['rerank']},
]
@@ -222,7 +242,12 @@ async def test_build_resources_accepts_dynamic_form_type_aliases(app):
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
{
'model_id': 'llm_alias',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream', 'count_tokens'],
},
]
assert resources['knowledge_bases'] == [
{'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
@@ -615,6 +615,94 @@ class TestAgentRunProxyActions:
assert response.data['usage'] == usage
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_count_tokens_validates_run_authorization_and_calls_provider(self, app):
"""COUNT_TOKENS is run-scoped and forwards messages/tools to the model requester."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
run_id = 'run_proxy_count_tokens'
query = self.query()
app.query_pool.cached_queries[906] = query
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=906,
plugin_identity='test/runner',
resources=make_agent_resources(
models=[{'model_id': 'llm_count_001', 'operations': ['count_tokens']}],
),
)
requester = SimpleNamespace(count_tokens=AsyncMock(return_value=37))
model = SimpleNamespace(
model_entity=SimpleNamespace(abilities=[], extra_args={'temperature': 0.2}),
provider=SimpleNamespace(requester=requester),
)
app.model_mgr.get_model_by_uuid.return_value = model
runtime_handler = make_handler(app)
try:
response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_count_001',
'messages': [{'role': 'user', 'content': 'hello'}],
'funcs': [{
'name': 'search',
'human_desc': 'Search',
'description': 'Search',
'parameters': {'type': 'object'},
}],
'extra_args': {'temperature': 0.7},
})
finally:
await registry.unregister(run_id)
assert response.code == 0
assert response.data == {'tokens': 37}
requester.count_tokens.assert_awaited_once()
kwargs = requester.count_tokens.await_args.kwargs
assert kwargs['model'] is model
assert kwargs['messages'][0].content == 'hello'
assert [tool.name for tool in kwargs['funcs']] == ['search']
assert kwargs['extra_args'] == {'temperature': 0.7}
@pytest.mark.asyncio
async def test_count_tokens_rejects_model_without_operation(self, app):
"""COUNT_TOKENS requires the explicit model operation in the run snapshot."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
run_id = 'run_proxy_count_tokens_denied'
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=None,
plugin_identity='test/runner',
resources=make_agent_resources(
models=[{'model_id': 'llm_count_002', 'operations': ['invoke']}],
),
)
runtime_handler = make_handler(app)
try:
response = await runtime_handler.actions[PluginToRuntimeAction.COUNT_TOKENS.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_count_002',
'messages': [{'role': 'user', 'content': 'hello'}],
})
finally:
await registry.unregister(run_id)
assert response.code != 0
assert 'operation count_tokens' in response.message
app.model_mgr.get_model_by_uuid.assert_not_awaited()
@pytest.mark.asyncio
async def test_invoke_llm_stream_restores_query_and_options(self, app):
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
@@ -1,11 +1,17 @@
"""Unit tests for provider_specific_fields round-trip in LiteLLMRequester.
"""Unit tests for LiteLLMRequester message/tool conversion.
This tests the fix for GitHub issue #1899: Gemini requires thought_signature
to be preserved across tool call rounds for function calls to work correctly.
This includes provider_specific_fields round-trip coverage for GitHub issue
#1899 and token counting preflight behavior for AgentRunner context budgeting.
"""
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
import pytest
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from langbot.pkg.provider.modelmgr import requester as model_requester
from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
@@ -14,6 +20,84 @@ def _make_requester() -> LiteLLMRequester:
return LiteLLMRequester.__new__(LiteLLMRequester)
def _make_configured_requester() -> LiteLLMRequester:
req = LiteLLMRequester.__new__(LiteLLMRequester)
req.requester_cfg = {
'base_url': '',
'timeout': 120,
'custom_llm_provider': 'openai',
'drop_params': False,
'num_retries': 0,
'api_version': '',
}
req.ap = SimpleNamespace(
tool_mgr=SimpleNamespace(
generate_tools_for_openai=AsyncMock(
return_value=[
{
'type': 'function',
'function': {
'name': 'search',
'description': 'Search',
'parameters': {'type': 'object'},
},
}
]
)
)
)
return req
def _make_runtime_model() -> model_requester.RuntimeLLMModel:
provider = SimpleNamespace(token_mgr=SimpleNamespace(get_token=Mock(return_value='sk-test')))
return SimpleNamespace(
model_entity=SimpleNamespace(
name='gpt-4.1',
extra_args={'temperature': 0.2},
),
provider=provider,
)
@pytest.mark.asyncio
async def test_count_tokens_uses_litellm_counter_with_request_messages_and_tools():
"""Token preflight uses the same LiteLLM request shape as chat completion."""
req = _make_configured_requester()
model = _make_runtime_model()
tool = resource_tool.LLMTool(
name='search',
human_desc='Search',
description='Search',
parameters={'type': 'object'},
func=lambda **kwargs: None,
)
with patch('langbot.pkg.provider.modelmgr.requesters.litellmchat.litellm.token_counter', return_value=42) as counter:
tokens = await req.count_tokens(
model=model,
messages=[
provider_message.Message(
role='user',
content=[
provider_message.ContentElement(type='text', text='hello'),
provider_message.ContentElement(type='file_url', file_url='https://example.test/a.pdf'),
],
)
],
funcs=[tool],
extra_args={'presence_penalty': 0.1},
)
assert tokens == 42
counter.assert_called_once()
kwargs = counter.call_args.kwargs
assert kwargs['model'] == 'openai/gpt-4.1'
assert kwargs['messages'] == [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]}]
assert kwargs['tools'][0]['function']['name'] == 'search'
assert kwargs['tool_choice'] == 'auto'
def test_convert_messages_preserves_tool_call_provider_specific_fields():
"""Tool calls should retain provider_specific_fields through _convert_messages."""
req = _make_requester()