Propagate agent runner model usage context

This commit is contained in:
huanghuoguoguo
2026-06-14 07:41:57 +08:00
parent 1153433693
commit 09adf4c541
9 changed files with 507 additions and 27 deletions
@@ -388,6 +388,7 @@ class TestAgentRunProxyActions:
def query(remove_think=True):
return SimpleNamespace(
pipeline_config={'output': {'misc': {'remove-think': remove_think}}},
variables={},
prompt=SimpleNamespace(
messages=[provider_message.Message(role='system', content='effective prompt')]
),
@@ -488,6 +489,60 @@ class TestAgentRunProxyActions:
assert kwargs['remove_think'] is True
assert [tool.name for tool in kwargs['funcs']] == ['search']
@pytest.mark.asyncio
async def test_invoke_llm_returns_provider_usage(self, app):
"""INVOKE_LLM includes optional provider usage in the action response."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
from langbot.pkg.provider.modelmgr import requester as model_requester
usage = {
'prompt_tokens': 11,
'completion_tokens': 7,
'total_tokens': 18,
'prompt_tokens_details': {'cached_tokens': 3},
}
class UsageProvider:
async def invoke_llm(self, **kwargs):
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
return provider_message.Message(role='assistant', content='ok')
run_id = 'run_proxy_invoke_llm_usage'
query = self.query()
app.query_pool.cached_queries[905] = query
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=905,
plugin_identity='test/runner',
resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]),
)
model = SimpleNamespace(
model_entity=SimpleNamespace(abilities=[], extra_args={}),
provider=UsageProvider(),
)
app.model_mgr.get_model_by_uuid.return_value = model
runtime_handler = make_handler(app)
try:
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_usage_001',
'messages': [{'role': 'user', 'content': 'hello'}],
})
finally:
await registry.unregister(run_id)
assert response.code == 0
assert response.data['message']['content'] == 'ok'
assert response.data['usage'] == usage
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_invoke_llm_stream_restores_query_and_options(self, app):
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
@@ -598,6 +653,63 @@ class TestAgentRunProxyActions:
assert [response.code for response in responses] == [0, 0]
assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done']
@pytest.mark.asyncio
async def test_invoke_llm_stream_returns_provider_usage_event(self, app):
"""INVOKE_LLM_STREAM emits a final usage-only action response when available."""
from langbot.pkg.agent.runner.session_registry import get_session_registry
from langbot.pkg.provider.modelmgr import requester as model_requester
usage = {
'prompt_tokens': 9,
'completion_tokens': 4,
'total_tokens': 13,
'prompt_tokens_details': {'cached_tokens': 2},
}
class StreamProvider:
async def invoke_llm_stream(self, **kwargs):
yield provider_message.MessageChunk(role='assistant', content='ok')
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
run_id = 'run_proxy_invoke_llm_stream_usage'
query = self.query()
app.query_pool.cached_queries[906] = query
registry = get_session_registry()
await registry.unregister(run_id)
await registry.register(
run_id=run_id,
runner_id='plugin:test/runner/default',
query_id=906,
plugin_identity='test/runner',
resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]),
)
model = SimpleNamespace(
model_entity=SimpleNamespace(abilities=[], extra_args={}),
provider=StreamProvider(),
)
app.model_mgr.get_model_by_uuid.return_value = model
runtime_handler = make_handler(app)
responses = []
try:
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
'run_id': run_id,
'caller_plugin_identity': 'test/runner',
'llm_model_uuid': 'llm_stream_usage_001',
'messages': [{'role': 'user', 'content': 'hello'}],
})
async for response in stream:
responses.append(response)
finally:
await registry.unregister(run_id)
assert [response.code for response in responses] == [0, 0]
assert responses[0].data['chunk']['content'] == 'ok'
assert responses[1].data == {'usage': usage}
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
@pytest.mark.asyncio
async def test_call_tool_passes_current_query(self, app):
"""CALL_TOOL passes the current Query back into tool execution."""