Propagate agent runner model usage context

This commit is contained in:
huanghuoguoguo
2026-06-14 07:41:57 +08:00
parent 1153433693
commit 09adf4c541
9 changed files with 507 additions and 27 deletions
@@ -2,6 +2,7 @@
from __future__ import annotations
import pytest
from types import SimpleNamespace
from unittest.mock import MagicMock, AsyncMock, patch
# SDK imports for validation
@@ -174,6 +175,87 @@ class TestContextValidation:
# Verify input
assert validated.input.text == "Hello world"
@pytest.mark.asyncio
async def test_build_context_from_event_populates_model_context_window(self):
"""Runtime metadata should expose the selected LLM model context window."""
mock_app = self._make_mock_app()
mock_app.model_mgr = MagicMock()
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
return_value=SimpleNamespace(
model_entity=SimpleNamespace(context_length=128000),
)
)
builder = AgentRunContextBuilder(mock_app)
event = self._make_event_envelope()
binding = self._make_binding()
resources = self._make_resources()
resources['models'] = [
{
'model_id': 'rerank-model',
'model_type': 'rerank',
'provider': 'test-provider',
'operations': ['rerank'],
},
{
'model_id': 'llm-model',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream'],
},
]
descriptor = self._make_descriptor()
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
mock_store = AsyncMock()
mock_store.build_snapshot_from_event = AsyncMock(return_value={
'conversation': {},
'actor': {},
'subject': {},
'runner': {},
})
mock_get_store.return_value = mock_store
context_dict = await builder.build_context_from_event(
event=event,
binding=binding,
descriptor=descriptor,
resources=resources,
)
assert context_dict['runtime']['metadata']['model_context_window_tokens'] == 128000
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('llm-model')
@pytest.mark.asyncio
async def test_model_context_window_uses_primary_llm_only(self):
"""Fallback model windows should not replace missing primary model metadata."""
mock_app = self._make_mock_app()
mock_app.model_mgr = MagicMock()
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
return_value=SimpleNamespace(
model_entity=SimpleNamespace(context_length=None),
)
)
builder = AgentRunContextBuilder(mock_app)
resources = self._make_resources()
resources['models'] = [
{
'model_id': 'primary-model',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream'],
},
{
'model_id': 'fallback-model',
'model_type': 'llm',
'provider': 'test-provider',
'operations': ['invoke', 'stream'],
},
]
assert await builder._build_model_context_window_tokens(resources) is None
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('primary-model')
@pytest.mark.asyncio
async def test_build_context_preserves_subject_data_for_non_message_events(self):
"""Non-message EBA events keep subject.data instead of relying on message text."""
@@ -388,6 +388,7 @@ class TestAgentRunProxyActions:
def query(remove_think=True):
return SimpleNamespace(
pipeline_config={'output': {'misc': {'remove-think': remove_think}}},
variables={},
prompt=SimpleNamespace(
messages=[provider_message.Message(role='system', content='effective prompt')]
),
@@ -488,6 +489,60 @@ class TestAgentRunProxyActions:
assert kwargs['remove_think'] is True
assert [tool.name for tool in kwargs['funcs']] == ['search']
@pytest.mark.asyncio
async def test_invoke_llm_returns_provider_usage(self, app):
"""INVOKE_LLM includes optional provider usage in the action response."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
from langbot.pkg.provider.modelmgr import requester as model_requester
usage = {
'prompt_tokens': 11,
'completion_tokens': 7,
'total_tokens': 18,
'prompt_tokens_details': {'cached_tokens': 3},
}
class UsageProvider:
async def invoke_llm(self, **kwargs):
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
return provider_message.Message(role='assistant', content='ok')
run_id = 'run_proxy_invoke_llm_usage'
query = self.query()
app.query_pool.cached_queries[905] = query
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=905,
plugin_identity='test/runner',
resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]),
)
model = SimpleNamespace(
model_entity=SimpleNamespace(abilities=[], extra_args={}),
provider=UsageProvider(),
)
app.model_mgr.get_model_by_uuid.return_value = model
runtime_handler = make_handler(app)
try:
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_usage_001',
'messages': [{'role': 'user', 'content': 'hello'}],
})
finally:
await registry.unregister(run_id)
assert response.code == 0
assert response.data['message']['content'] == 'ok'
assert response.data['usage'] == usage
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_invoke_llm_stream_restores_query_and_options(self, app):
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
@@ -598,6 +653,63 @@ class TestAgentRunProxyActions:
assert [response.code for response in responses] == [0, 0]
assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done']
@pytest.mark.asyncio
async def test_invoke_llm_stream_returns_provider_usage_event(self, app):
"""INVOKE_LLM_STREAM emits a final usage-only action response when available."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
from langbot.pkg.provider.modelmgr import requester as model_requester
usage = {
'prompt_tokens': 9,
'completion_tokens': 4,
'total_tokens': 13,
'prompt_tokens_details': {'cached_tokens': 2},
}
class StreamProvider:
async def invoke_llm_stream(self, **kwargs):
yield provider_message.MessageChunk(role='assistant', content='ok')
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
run_id = 'run_proxy_invoke_llm_stream_usage'
query = self.query()
app.query_pool.cached_queries[906] = query
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=906,
plugin_identity='test/runner',
resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]),
)
model = SimpleNamespace(
model_entity=SimpleNamespace(abilities=[], extra_args={}),
provider=StreamProvider(),
)
app.model_mgr.get_model_by_uuid.return_value = model
runtime_handler = make_handler(app)
responses = []
try:
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_stream_usage_001',
'messages': [{'role': 'user', 'content': 'hello'}],
})
async for response in stream:
responses.append(response)
finally:
await registry.unregister(run_id)
assert [response.code for response in responses] == [0, 0]
assert responses[0].data['chunk']['content'] == 'ok'
assert responses[1].data == {'usage': usage}
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_call_tool_passes_current_query(self, app):
"""CALL_TOOL passes the current Query back into tool execution."""
@@ -115,6 +115,15 @@ class TestExtractUsage:
assert result['prompt_tokens'] == 0
assert result['completion_tokens'] == 0
def test_extract_usage_without_provider_usage(self):
"""Missing provider usage is not treated as authoritative zero usage."""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
response = Mock()
response.usage = None
assert requester._extract_usage(response) is None
class TestNormalizeUsage:
"""Test _normalize_usage helper covering real-world usage shapes"""
@@ -131,6 +140,22 @@ class TestNormalizeUsage:
)
assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20}
def test_preserves_token_details(self):
"""Provider token details such as cache counters are preserved."""
result = litellmchat.LiteLLMRequester._normalize_usage(
{
'prompt_tokens': 12,
'completion_tokens': 8,
'total_tokens': 20,
'prompt_tokens_details': {'cached_tokens': 7},
'completion_tokens_details': {'reasoning_tokens': 3},
}
)
assert result['prompt_tokens'] == 12
assert result['prompt_tokens_details'] == {'cached_tokens': 7}
assert result['completion_tokens_details'] == {'reasoning_tokens': 3}
def test_missing_total_is_derived(self):
"""When total_tokens is absent/zero it is derived from prompt + completion"""
usage = Mock()
@@ -299,6 +299,59 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
assert result.role == 'assistant'
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stashes_usage(runtime_provider, runtime_llm_model):
"""RuntimeProvider preserves requester usage for upstream action handlers."""
provider = runtime_provider
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
query = pipeline_query.Query.model_construct(
query_id='test-query-usage',
launcher_type='person',
launcher_id=12345,
sender_id=12345,
message_chain=None,
message_event=None,
adapter=None,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None,
)
usage = {
'prompt_tokens': 11,
'completion_tokens': 7,
'total_tokens': 18,
'prompt_tokens_details': {'cached_tokens': 3},
}
provider.requester.invoke_llm = AsyncMock(
return_value=(
provider_message.Message(role='assistant', content='ok'),
usage,
)
)
result = await provider.invoke_llm(
query,
runtime_llm_model,
[provider_message.Message(role='user', content='Hello')],
)
assert result.content == 'ok'
assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider, runtime_llm_model):
"""Test RuntimeProvider.invoke_llm_stream yields chunks from requester."""
@@ -340,6 +393,62 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
assert chunks[0].role == 'assistant'
@pytest.mark.asyncio
async def test_runtime_provider_invoke_llm_stream_stashes_usage(runtime_provider, runtime_llm_model):
"""RuntimeProvider transfers captured stream usage to the public query usage key."""
provider = runtime_provider
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
query = pipeline_query.Query.model_construct(
query_id='test-stream-usage',
launcher_type='person',
launcher_id=12345,
sender_id=12345,
message_chain=None,
message_event=None,
adapter=None,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
session=None,
prompt=None,
messages=[],
user_message=None,
use_funcs=[],
use_llm_model_uuid=None,
variables={},
resp_messages=[],
resp_message_chain=None,
current_stage_name=None,
)
usage = {
'prompt_tokens': 13,
'completion_tokens': 2,
'total_tokens': 15,
}
async def fake_stream(**kwargs):
kwargs['query'].variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage
yield provider_message.MessageChunk(role='assistant', content='ok')
provider.requester.invoke_llm_stream = fake_stream
chunks = [
chunk
async for chunk in provider.invoke_llm_stream(
query,
runtime_llm_model,
[provider_message.Message(role='user', content='Hello')],
)
]
assert len(chunks) == 1
assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
assert requester.STREAM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_runtime_provider_invoke_embedding_returns_vectors(runtime_provider, runtime_embedding_model):
"""Test RuntimeProvider.invoke_embedding returns embedding vectors."""