diff --git a/docs/agent-runner-pluginization/PROTOCOL_V1.md b/docs/agent-runner-pluginization/PROTOCOL_V1.md index 7e504498..c40d13c0 100644 --- a/docs/agent-runner-pluginization/PROTOCOL_V1.md +++ b/docs/agent-runner-pluginization/PROTOCOL_V1.md @@ -475,8 +475,11 @@ Host 必须校验 `state.updated` 的 scope、key、value 大小和 JSON 可序 ```python # Model await api.invoke_llm(llm_model_uuid, messages, funcs=None, extra_args=None) +await api.invoke_llm_with_usage(llm_model_uuid, messages, funcs=None, extra_args=None) async for chunk in api.invoke_llm_stream(llm_model_uuid, messages, funcs=None, extra_args=None): ... +async for event in api.invoke_llm_stream_events(llm_model_uuid, messages, funcs=None, extra_args=None): + ... await api.invoke_rerank(rerank_model_id, query, documents, top_k=None) # Tool @@ -519,6 +522,16 @@ await api.get_langbot_version() `llm_model_uuid`,wire payload 字段也是 `llm_model_uuid`。该值对 runner 仍是 opaque identifier,不应解析其内部格式。 +`invoke_llm()` 和 `invoke_llm_stream()` 保持兼容:前者返回 `Message`,后者只 +yield `MessageChunk`。需要 provider 真实 token 计量的 runner 应使用 +`invoke_llm_with_usage()` 或 `invoke_llm_stream_events()`。Host response 可在 +原有 `{message: ...}` / `{chunk: ...}` 外额外携带可选 `usage` 字段;streaming +场景允许在所有 chunk 之后追加一个 usage-only event。`usage` 至少保留 +OpenAI-compatible 的 `prompt_tokens`、`completion_tokens`、`total_tokens`, +若 provider 返回 `prompt_tokens_details` / `completion_tokens_details` 或 +cache token counters,Host / SDK 不应丢弃这些字段。没有 usage 的 provider +必须继续返回成功响应,SDK 将 usage 置为 `None`。 + `get_prompt()` 返回当前 query-backed run 的 Host effective prompt messages: `list[Message]` 的 JSON 形式。该能力只在 `ctx.context.available_apis.prompt_get` 为 true 时可用;没有 query 缓存、prompt 已过期或非 query entry run 时 Host diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py index 8f51e3e6..665fce64 100644 --- a/src/langbot/pkg/agent/runner/context_builder.py +++ b/src/langbot/pkg/agent/runner/context_builder.py @@ -179,6 +179,52 @@ class AgentRunContextBuilder: def __init__(self, ap: app.Application): self.ap = ap + @staticmethod + def _positive_int(value: typing.Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int) and value > 0: + return value + if isinstance(value, str) and value.isdigit(): + parsed_value = int(value) + if parsed_value > 0: + return parsed_value + return None + + @staticmethod + def _is_llm_model_resource(model_resource: ModelResource) -> bool: + operations = model_resource.get('operations') + if isinstance(operations, list) and operations: + return bool({'invoke', 'stream'} & {str(operation) for operation in operations}) + return model_resource.get('model_type') != 'rerank' + + async def _build_model_context_window_tokens(self, resources: AgentResources) -> int | None: + model_mgr = getattr(self.ap, 'model_mgr', None) + if model_mgr is None: + return None + + for model_resource in resources.get('models', []): + if not self._is_llm_model_resource(model_resource): + continue + + model_uuid = model_resource.get('model_id') + if not isinstance(model_uuid, str) or not model_uuid: + continue + + try: + model = await model_mgr.get_model_by_uuid(model_uuid) + except Exception as exc: + logger = getattr(self.ap, 'logger', None) + if logger is not None: + logger.debug(f'Failed to resolve model context window for {model_uuid}: {exc}') + continue + + model_entity = getattr(model, 'model_entity', None) + context_length = self._positive_int(getattr(model_entity, 'context_length', None)) + return context_length + + return None + async def build_context_from_event( self, event: AgentEventEnvelope, @@ -270,6 +316,8 @@ class AgentRunContextBuilder: persistent_state_store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine()) state: AgentRunState = await persistent_state_store.build_snapshot_from_event(event, binding, descriptor) + model_context_window_tokens = await self._build_model_context_window_tokens(resources) + # Build runtime context runtime: AgentRuntimeContext = { 'langbot_version': self.ap.ver_mgr.get_current_version(), @@ -279,10 +327,7 @@ class AgentRunContextBuilder: 'bot_id': event.bot_id, 'workspace_id': event.workspace_id, 'streaming_supported': event.delivery.supports_streaming, - 'model_context_window_tokens': None, - # TODO(model-info): populate model_context_window_tokens after - # LiteLLM/model metadata lands. Runners fall back to their - # ctx.config until Host can provide the real window. + 'model_context_window_tokens': model_context_window_tokens, }, } diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 41bf3f4f..6043aa2b 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -21,6 +21,7 @@ import langbot_plugin.api.entities.builtin.resource.tool as resource_tool from ..entity.persistence import plugin as persistence_plugin from ..entity.persistence import bstorage as persistence_bstorage +from ..provider.modelmgr import requester as model_requester from ..core import app from ..utils import constants @@ -43,6 +44,18 @@ def _make_rag_error_response(error: Exception, error_type: str, **extra_context) return handler.ActionResponse.error(message=message) +def _pop_query_llm_usage(query: Any) -> dict[str, Any] | None: + """Read provider usage stashed on a query by RuntimeProvider.""" + if query is None or not getattr(query, 'variables', None): + return None + usage = query.variables.pop(model_requester.LLM_USAGE_QUERY_VARIABLE, None) + if usage is None: + return None + if isinstance(usage, dict): + return dict(usage) + return None + + def _i18n_to_dict(value: Any) -> dict[str, Any]: """Convert SDK i18n values to plain dictionaries.""" if value is None: @@ -802,10 +815,20 @@ class RuntimeConnectionHandler(handler.Handler): remove_think=remove_think, ) + usage = None + if isinstance(result, tuple): + result, usage = result + if usage is None: + usage = _pop_query_llm_usage(query) + + response_data = { + 'message': result.model_dump(), + } + if usage is not None: + response_data['usage'] = usage + return handler.ActionResponse.success( - data={ - 'message': result.model_dump(), - }, + data=response_data, ) @self.action(PluginToRuntimeAction.INVOKE_LLM_STREAM) @@ -867,6 +890,13 @@ class RuntimeConnectionHandler(handler.Handler): 'chunk': chunk.model_dump(), }, ) + usage = _pop_query_llm_usage(query) + if usage is not None: + yield handler.ActionResponse.success( + data={ + 'usage': usage, + }, + ) @self.action(PluginToRuntimeAction.CALL_TOOL) async def call_tool(data: dict[str, Any]) -> handler.ActionResponse: diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index b673c758..377f7d4a 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message +LLM_USAGE_QUERY_VARIABLE = '_llm_usage' +STREAM_USAGE_QUERY_VARIABLE = '_stream_usage' + + +def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None: + """Store the latest provider usage on the query for upstream action handlers.""" + if query is None or not usage_info: + return + if query.variables is None: + query.variables = {} + query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info) + + class RuntimeProvider: """运行时模型提供商""" @@ -67,6 +80,7 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: + _store_llm_usage(query, usage_info) input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) return msg @@ -146,11 +160,12 @@ class RuntimeProvider: if query: if query.variables is None: query.variables = {} - if '_stream_usage' in query.variables: - usage_info = query.variables['_stream_usage'] + if STREAM_USAGE_QUERY_VARIABLE in query.variables: + usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE] + _store_llm_usage(query, usage_info) input_tokens = usage_info.get('prompt_tokens', 0) output_tokens = usage_info.get('completion_tokens', 0) - del query.variables['_stream_usage'] + del query.variables[STREAM_USAGE_QUERY_VARIABLE] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 6b087916..6e5d53e8 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -250,32 +250,81 @@ class LiteLLMRequester(requester.ProviderAPIRequester): - dict with the same keys - missing ``total_tokens`` (derived from prompt + completion) - ``None`` / partially-populated usage (defaults to 0) + - provider-specific token details, including cache token counters """ - if usage is None: - return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0} + def _plain_value(value: typing.Any) -> typing.Any: + if value is None: + return None + if isinstance(value, dict): + return {k: _plain_value(v) for k, v in value.items() if v is not None} + if isinstance(value, (list, tuple)): + return [_plain_value(v) for v in value] - def _get(key: str) -> typing.Any: - if isinstance(usage, dict): - return usage.get(key) - return getattr(usage, key, None) + model_dump = getattr(value, 'model_dump', None) + if callable(model_dump): + try: + dumped = model_dump() + if isinstance(dumped, dict): + return _plain_value(dumped) + except Exception: + pass - prompt_tokens = _get('prompt_tokens') or 0 - completion_tokens = _get('completion_tokens') or 0 - total_tokens = _get('total_tokens') or 0 + return value + + def _usage_dict(value: typing.Any) -> dict[str, typing.Any]: + if value is None: + return {} + plain = _plain_value(value) + if isinstance(plain, dict): + return plain + + def _is_mock_attr(attr: typing.Any) -> bool: + return type(attr).__module__.startswith('unittest.mock') + + data: dict[str, typing.Any] = {} + for key in ( + 'prompt_tokens', + 'completion_tokens', + 'total_tokens', + 'prompt_tokens_details', + 'completion_tokens_details', + 'cache_creation_input_tokens', + 'cache_read_input_tokens', + 'input_token_details', + 'output_token_details', + ): + attr_value = getattr(value, key, None) + if attr_value is not None and not _is_mock_attr(attr_value): + data[key] = _plain_value(attr_value) + return data + + def _to_int(value: typing.Any) -> int: + try: + return int(value or 0) + except (TypeError, ValueError): + return 0 + + normalized = _usage_dict(usage) + + prompt_tokens = _to_int(normalized.get('prompt_tokens')) + completion_tokens = _to_int(normalized.get('completion_tokens')) + total_tokens = _to_int(normalized.get('total_tokens')) # Some providers omit total_tokens in streaming usage; derive it. if not total_tokens: total_tokens = prompt_tokens + completion_tokens - return { - 'prompt_tokens': int(prompt_tokens), - 'completion_tokens': int(completion_tokens), - 'total_tokens': int(total_tokens), - } + normalized['prompt_tokens'] = prompt_tokens + normalized['completion_tokens'] = completion_tokens + normalized['total_tokens'] = total_tokens + return normalized - def _extract_usage(self, response) -> dict: + def _extract_usage(self, response) -> dict | None: """Extract usage info from a non-streaming LiteLLM response.""" - return self._normalize_usage(getattr(response, 'usage', None)) + usage = getattr(response, 'usage', None) + if usage is None: + return None + return self._normalize_usage(usage) @staticmethod def _as_dict(value: typing.Any) -> dict: @@ -474,7 +523,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): if query is not None: if query.variables is None: query.variables = {} - query.variables['_stream_usage'] = usage_info + query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info if not hasattr(chunk, 'choices') or not chunk.choices: continue diff --git a/tests/unit_tests/agent/test_context_validation.py b/tests/unit_tests/agent/test_context_validation.py index b933c7c1..3cc36791 100644 --- a/tests/unit_tests/agent/test_context_validation.py +++ b/tests/unit_tests/agent/test_context_validation.py @@ -2,6 +2,7 @@ from __future__ import annotations import pytest +from types import SimpleNamespace from unittest.mock import MagicMock, AsyncMock, patch # SDK imports for validation @@ -174,6 +175,87 @@ class TestContextValidation: # Verify input assert validated.input.text == "Hello world" + @pytest.mark.asyncio + async def test_build_context_from_event_populates_model_context_window(self): + """Runtime metadata should expose the selected LLM model context window.""" + mock_app = self._make_mock_app() + mock_app.model_mgr = MagicMock() + mock_app.model_mgr.get_model_by_uuid = AsyncMock( + return_value=SimpleNamespace( + model_entity=SimpleNamespace(context_length=128000), + ) + ) + builder = AgentRunContextBuilder(mock_app) + + event = self._make_event_envelope() + binding = self._make_binding() + resources = self._make_resources() + resources['models'] = [ + { + 'model_id': 'rerank-model', + 'model_type': 'rerank', + 'provider': 'test-provider', + 'operations': ['rerank'], + }, + { + 'model_id': 'llm-model', + 'model_type': 'llm', + 'provider': 'test-provider', + 'operations': ['invoke', 'stream'], + }, + ] + descriptor = self._make_descriptor() + + with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store: + mock_store = AsyncMock() + mock_store.build_snapshot_from_event = AsyncMock(return_value={ + 'conversation': {}, + 'actor': {}, + 'subject': {}, + 'runner': {}, + }) + mock_get_store.return_value = mock_store + + context_dict = await builder.build_context_from_event( + event=event, + binding=binding, + descriptor=descriptor, + resources=resources, + ) + + assert context_dict['runtime']['metadata']['model_context_window_tokens'] == 128000 + mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('llm-model') + + @pytest.mark.asyncio + async def test_model_context_window_uses_primary_llm_only(self): + """Fallback model windows should not replace missing primary model metadata.""" + mock_app = self._make_mock_app() + mock_app.model_mgr = MagicMock() + mock_app.model_mgr.get_model_by_uuid = AsyncMock( + return_value=SimpleNamespace( + model_entity=SimpleNamespace(context_length=None), + ) + ) + builder = AgentRunContextBuilder(mock_app) + resources = self._make_resources() + resources['models'] = [ + { + 'model_id': 'primary-model', + 'model_type': 'llm', + 'provider': 'test-provider', + 'operations': ['invoke', 'stream'], + }, + { + 'model_id': 'fallback-model', + 'model_type': 'llm', + 'provider': 'test-provider', + 'operations': ['invoke', 'stream'], + }, + ] + + assert await builder._build_model_context_window_tokens(resources) is None + mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('primary-model') + @pytest.mark.asyncio async def test_build_context_preserves_subject_data_for_non_message_events(self): """Non-message EBA events keep subject.data instead of relying on message text.""" diff --git a/tests/unit_tests/plugin/test_handler_actions.py b/tests/unit_tests/plugin/test_handler_actions.py index 68f2180f..7e1c0ff6 100644 --- a/tests/unit_tests/plugin/test_handler_actions.py +++ b/tests/unit_tests/plugin/test_handler_actions.py @@ -388,6 +388,7 @@ class TestAgentRunProxyActions: def query(remove_think=True): return SimpleNamespace( pipeline_config={'output': {'misc': {'remove-think': remove_think}}}, + variables={}, prompt=SimpleNamespace( messages=[provider_message.Message(role='system', content='effective prompt')] ), @@ -488,6 +489,60 @@ class TestAgentRunProxyActions: assert kwargs['remove_think'] is True assert [tool.name for tool in kwargs['funcs']] == ['search'] + @pytest.mark.asyncio + async def test_invoke_llm_returns_provider_usage(self, app): + """INVOKE_LLM includes optional provider usage in the action response.""" + from langbot.pkg.agent.runner.session_registry import get_session_registry + from langbot.pkg.provider.modelmgr import requester as model_requester + + usage = { + 'prompt_tokens': 11, + 'completion_tokens': 7, + 'total_tokens': 18, + 'prompt_tokens_details': {'cached_tokens': 3}, + } + + class UsageProvider: + async def invoke_llm(self, **kwargs): + kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage + return provider_message.Message(role='assistant', content='ok') + + run_id = 'run_proxy_invoke_llm_usage' + query = self.query() + app.query_pool.cached_queries[905] = query + + registry = get_session_registry() + await registry.unregister(run_id) + await registry.register( + run_id=run_id, + runner_id='plugin:test/runner/default', + query_id=905, + plugin_identity='test/runner', + resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]), + ) + + model = SimpleNamespace( + model_entity=SimpleNamespace(abilities=[], extra_args={}), + provider=UsageProvider(), + ) + app.model_mgr.get_model_by_uuid.return_value = model + runtime_handler = make_handler(app) + + try: + response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({ + 'run_id': run_id, + 'caller_plugin_identity': 'test/runner', + 'llm_model_uuid': 'llm_usage_001', + 'messages': [{'role': 'user', 'content': 'hello'}], + }) + finally: + await registry.unregister(run_id) + + assert response.code == 0 + assert response.data['message']['content'] == 'ok' + assert response.data['usage'] == usage + assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables + @pytest.mark.asyncio async def test_invoke_llm_stream_restores_query_and_options(self, app): """INVOKE_LLM_STREAM applies the same host context as non-streaming calls.""" @@ -598,6 +653,63 @@ class TestAgentRunProxyActions: assert [response.code for response in responses] == [0, 0] assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done'] + @pytest.mark.asyncio + async def test_invoke_llm_stream_returns_provider_usage_event(self, app): + """INVOKE_LLM_STREAM emits a final usage-only action response when available.""" + from langbot.pkg.agent.runner.session_registry import get_session_registry + from langbot.pkg.provider.modelmgr import requester as model_requester + + usage = { + 'prompt_tokens': 9, + 'completion_tokens': 4, + 'total_tokens': 13, + 'prompt_tokens_details': {'cached_tokens': 2}, + } + + class StreamProvider: + async def invoke_llm_stream(self, **kwargs): + yield provider_message.MessageChunk(role='assistant', content='ok') + kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage + + run_id = 'run_proxy_invoke_llm_stream_usage' + query = self.query() + app.query_pool.cached_queries[906] = query + + registry = get_session_registry() + await registry.unregister(run_id) + await registry.register( + run_id=run_id, + runner_id='plugin:test/runner/default', + query_id=906, + plugin_identity='test/runner', + resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]), + ) + + model = SimpleNamespace( + model_entity=SimpleNamespace(abilities=[], extra_args={}), + provider=StreamProvider(), + ) + app.model_mgr.get_model_by_uuid.return_value = model + runtime_handler = make_handler(app) + + responses = [] + try: + stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({ + 'run_id': run_id, + 'caller_plugin_identity': 'test/runner', + 'llm_model_uuid': 'llm_stream_usage_001', + 'messages': [{'role': 'user', 'content': 'hello'}], + }) + async for response in stream: + responses.append(response) + finally: + await registry.unregister(run_id) + + assert [response.code for response in responses] == [0, 0] + assert responses[0].data['chunk']['content'] == 'ok' + assert responses[1].data == {'usage': usage} + assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables + @pytest.mark.asyncio async def test_call_tool_passes_current_query(self, app): """CALL_TOOL passes the current Query back into tool execution.""" diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index 1ec12d82..12bc6dbb 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -115,6 +115,15 @@ class TestExtractUsage: assert result['prompt_tokens'] == 0 assert result['completion_tokens'] == 0 + def test_extract_usage_without_provider_usage(self): + """Missing provider usage is not treated as authoritative zero usage.""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + response = Mock() + response.usage = None + + assert requester._extract_usage(response) is None + class TestNormalizeUsage: """Test _normalize_usage helper covering real-world usage shapes""" @@ -131,6 +140,22 @@ class TestNormalizeUsage: ) assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20} + def test_preserves_token_details(self): + """Provider token details such as cache counters are preserved.""" + result = litellmchat.LiteLLMRequester._normalize_usage( + { + 'prompt_tokens': 12, + 'completion_tokens': 8, + 'total_tokens': 20, + 'prompt_tokens_details': {'cached_tokens': 7}, + 'completion_tokens_details': {'reasoning_tokens': 3}, + } + ) + + assert result['prompt_tokens'] == 12 + assert result['prompt_tokens_details'] == {'cached_tokens': 7} + assert result['completion_tokens_details'] == {'reasoning_tokens': 3} + def test_missing_total_is_derived(self): """When total_tokens is absent/zero it is derived from prompt + completion""" usage = Mock() diff --git a/tests/unit_tests/provider/test_requester_base.py b/tests/unit_tests/provider/test_requester_base.py index c34556cd..0026f388 100644 --- a/tests/unit_tests/provider/test_requester_base.py +++ b/tests/unit_tests/provider/test_requester_base.py @@ -299,6 +299,59 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l assert result.role == 'assistant' +@pytest.mark.asyncio +async def test_runtime_provider_invoke_llm_stashes_usage(runtime_provider, runtime_llm_model): + """RuntimeProvider preserves requester usage for upstream action handlers.""" + provider = runtime_provider + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + + query = pipeline_query.Query.model_construct( + query_id='test-query-usage', + launcher_type='person', + launcher_id=12345, + sender_id=12345, + message_chain=None, + message_event=None, + adapter=None, + pipeline_uuid='pipeline-uuid', + bot_uuid='bot-uuid', + pipeline_config={'ai': {}, 'output': {}, 'trigger': {}}, + session=None, + prompt=None, + messages=[], + user_message=None, + use_funcs=[], + use_llm_model_uuid=None, + variables={}, + resp_messages=[], + resp_message_chain=None, + current_stage_name=None, + ) + usage = { + 'prompt_tokens': 11, + 'completion_tokens': 7, + 'total_tokens': 18, + 'prompt_tokens_details': {'cached_tokens': 3}, + } + provider.requester.invoke_llm = AsyncMock( + return_value=( + provider_message.Message(role='assistant', content='ok'), + usage, + ) + ) + + result = await provider.invoke_llm( + query, + runtime_llm_model, + [provider_message.Message(role='user', content='Hello')], + ) + + assert result.content == 'ok' + assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage + + @pytest.mark.asyncio async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider, runtime_llm_model): """Test RuntimeProvider.invoke_llm_stream yields chunks from requester.""" @@ -340,6 +393,62 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider assert chunks[0].role == 'assistant' +@pytest.mark.asyncio +async def test_runtime_provider_invoke_llm_stream_stashes_usage(runtime_provider, runtime_llm_model): + """RuntimeProvider transfers captured stream usage to the public query usage key.""" + provider = runtime_provider + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + + query = pipeline_query.Query.model_construct( + query_id='test-stream-usage', + launcher_type='person', + launcher_id=12345, + sender_id=12345, + message_chain=None, + message_event=None, + adapter=None, + pipeline_uuid='pipeline-uuid', + bot_uuid='bot-uuid', + pipeline_config={'ai': {}, 'output': {}, 'trigger': {}}, + session=None, + prompt=None, + messages=[], + user_message=None, + use_funcs=[], + use_llm_model_uuid=None, + variables={}, + resp_messages=[], + resp_message_chain=None, + current_stage_name=None, + ) + usage = { + 'prompt_tokens': 13, + 'completion_tokens': 2, + 'total_tokens': 15, + } + + async def fake_stream(**kwargs): + kwargs['query'].variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage + yield provider_message.MessageChunk(role='assistant', content='ok') + + provider.requester.invoke_llm_stream = fake_stream + + chunks = [ + chunk + async for chunk in provider.invoke_llm_stream( + query, + runtime_llm_model, + [provider_message.Message(role='user', content='Hello')], + ) + ] + + assert len(chunks) == 1 + assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage + assert requester.STREAM_USAGE_QUERY_VARIABLE not in query.variables + + @pytest.mark.asyncio async def test_runtime_provider_invoke_embedding_returns_vectors(runtime_provider, runtime_embedding_model): """Test RuntimeProvider.invoke_embedding returns embedding vectors."""