mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 01:36:03 +00:00
Propagate agent runner model usage context
This commit is contained in:
@@ -475,8 +475,11 @@ Host 必须校验 `state.updated` 的 scope、key、value 大小和 JSON 可序
|
||||
```python
|
||||
# Model
|
||||
await api.invoke_llm(llm_model_uuid, messages, funcs=None, extra_args=None)
|
||||
await api.invoke_llm_with_usage(llm_model_uuid, messages, funcs=None, extra_args=None)
|
||||
async for chunk in api.invoke_llm_stream(llm_model_uuid, messages, funcs=None, extra_args=None):
|
||||
...
|
||||
async for event in api.invoke_llm_stream_events(llm_model_uuid, messages, funcs=None, extra_args=None):
|
||||
...
|
||||
await api.invoke_rerank(rerank_model_id, query, documents, top_k=None)
|
||||
|
||||
# Tool
|
||||
@@ -519,6 +522,16 @@ await api.get_langbot_version()
|
||||
`llm_model_uuid`,wire payload 字段也是 `llm_model_uuid`。该值对 runner
|
||||
仍是 opaque identifier,不应解析其内部格式。
|
||||
|
||||
`invoke_llm()` 和 `invoke_llm_stream()` 保持兼容:前者返回 `Message`,后者只
|
||||
yield `MessageChunk`。需要 provider 真实 token 计量的 runner 应使用
|
||||
`invoke_llm_with_usage()` 或 `invoke_llm_stream_events()`。Host response 可在
|
||||
原有 `{message: ...}` / `{chunk: ...}` 外额外携带可选 `usage` 字段;streaming
|
||||
场景允许在所有 chunk 之后追加一个 usage-only event。`usage` 至少保留
|
||||
OpenAI-compatible 的 `prompt_tokens`、`completion_tokens`、`total_tokens`,
|
||||
若 provider 返回 `prompt_tokens_details` / `completion_tokens_details` 或
|
||||
cache token counters,Host / SDK 不应丢弃这些字段。没有 usage 的 provider
|
||||
必须继续返回成功响应,SDK 将 usage 置为 `None`。
|
||||
|
||||
`get_prompt()` 返回当前 query-backed run 的 Host effective prompt messages:
|
||||
`list[Message]` 的 JSON 形式。该能力只在 `ctx.context.available_apis.prompt_get`
|
||||
为 true 时可用;没有 query 缓存、prompt 已过期或非 query entry run 时 Host
|
||||
|
||||
@@ -179,6 +179,52 @@ class AgentRunContextBuilder:
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
@staticmethod
|
||||
def _positive_int(value: typing.Any) -> int | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, int) and value > 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
parsed_value = int(value)
|
||||
if parsed_value > 0:
|
||||
return parsed_value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_llm_model_resource(model_resource: ModelResource) -> bool:
|
||||
operations = model_resource.get('operations')
|
||||
if isinstance(operations, list) and operations:
|
||||
return bool({'invoke', 'stream'} & {str(operation) for operation in operations})
|
||||
return model_resource.get('model_type') != 'rerank'
|
||||
|
||||
async def _build_model_context_window_tokens(self, resources: AgentResources) -> int | None:
|
||||
model_mgr = getattr(self.ap, 'model_mgr', None)
|
||||
if model_mgr is None:
|
||||
return None
|
||||
|
||||
for model_resource in resources.get('models', []):
|
||||
if not self._is_llm_model_resource(model_resource):
|
||||
continue
|
||||
|
||||
model_uuid = model_resource.get('model_id')
|
||||
if not isinstance(model_uuid, str) or not model_uuid:
|
||||
continue
|
||||
|
||||
try:
|
||||
model = await model_mgr.get_model_by_uuid(model_uuid)
|
||||
except Exception as exc:
|
||||
logger = getattr(self.ap, 'logger', None)
|
||||
if logger is not None:
|
||||
logger.debug(f'Failed to resolve model context window for {model_uuid}: {exc}')
|
||||
continue
|
||||
|
||||
model_entity = getattr(model, 'model_entity', None)
|
||||
context_length = self._positive_int(getattr(model_entity, 'context_length', None))
|
||||
return context_length
|
||||
|
||||
return None
|
||||
|
||||
async def build_context_from_event(
|
||||
self,
|
||||
event: AgentEventEnvelope,
|
||||
@@ -270,6 +316,8 @@ class AgentRunContextBuilder:
|
||||
persistent_state_store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
|
||||
state: AgentRunState = await persistent_state_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
|
||||
model_context_window_tokens = await self._build_model_context_window_tokens(resources)
|
||||
|
||||
# Build runtime context
|
||||
runtime: AgentRuntimeContext = {
|
||||
'langbot_version': self.ap.ver_mgr.get_current_version(),
|
||||
@@ -279,10 +327,7 @@ class AgentRunContextBuilder:
|
||||
'bot_id': event.bot_id,
|
||||
'workspace_id': event.workspace_id,
|
||||
'streaming_supported': event.delivery.supports_streaming,
|
||||
'model_context_window_tokens': None,
|
||||
# TODO(model-info): populate model_context_window_tokens after
|
||||
# LiteLLM/model metadata lands. Runners fall back to their
|
||||
# ctx.config until Host can provide the real window.
|
||||
'model_context_window_tokens': model_context_window_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
from ..entity.persistence import plugin as persistence_plugin
|
||||
from ..entity.persistence import bstorage as persistence_bstorage
|
||||
from ..provider.modelmgr import requester as model_requester
|
||||
|
||||
from ..core import app
|
||||
from ..utils import constants
|
||||
@@ -43,6 +44,18 @@ def _make_rag_error_response(error: Exception, error_type: str, **extra_context)
|
||||
return handler.ActionResponse.error(message=message)
|
||||
|
||||
|
||||
def _pop_query_llm_usage(query: Any) -> dict[str, Any] | None:
|
||||
"""Read provider usage stashed on a query by RuntimeProvider."""
|
||||
if query is None or not getattr(query, 'variables', None):
|
||||
return None
|
||||
usage = query.variables.pop(model_requester.LLM_USAGE_QUERY_VARIABLE, None)
|
||||
if usage is None:
|
||||
return None
|
||||
if isinstance(usage, dict):
|
||||
return dict(usage)
|
||||
return None
|
||||
|
||||
|
||||
def _i18n_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert SDK i18n values to plain dictionaries."""
|
||||
if value is None:
|
||||
@@ -802,10 +815,20 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
usage = None
|
||||
if isinstance(result, tuple):
|
||||
result, usage = result
|
||||
if usage is None:
|
||||
usage = _pop_query_llm_usage(query)
|
||||
|
||||
response_data = {
|
||||
'message': result.model_dump(),
|
||||
}
|
||||
if usage is not None:
|
||||
response_data['usage'] = usage
|
||||
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'message': result.model_dump(),
|
||||
},
|
||||
data=response_data,
|
||||
)
|
||||
|
||||
@self.action(PluginToRuntimeAction.INVOKE_LLM_STREAM)
|
||||
@@ -867,6 +890,13 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
'chunk': chunk.model_dump(),
|
||||
},
|
||||
)
|
||||
usage = _pop_query_llm_usage(query)
|
||||
if usage is not None:
|
||||
yield handler.ActionResponse.success(
|
||||
data={
|
||||
'usage': usage,
|
||||
},
|
||||
)
|
||||
|
||||
@self.action(PluginToRuntimeAction.CALL_TOOL)
|
||||
async def call_tool(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
|
||||
@@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
|
||||
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
|
||||
|
||||
|
||||
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
|
||||
"""Store the latest provider usage on the query for upstream action handlers."""
|
||||
if query is None or not usage_info:
|
||||
return
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info)
|
||||
|
||||
|
||||
class RuntimeProvider:
|
||||
"""运行时模型提供商"""
|
||||
|
||||
@@ -67,6 +80,7 @@ class RuntimeProvider:
|
||||
if isinstance(result, tuple):
|
||||
msg, usage_info = result
|
||||
if usage_info:
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
return msg
|
||||
@@ -146,11 +160,12 @@ class RuntimeProvider:
|
||||
if query:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
if '_stream_usage' in query.variables:
|
||||
usage_info = query.variables['_stream_usage']
|
||||
if STREAM_USAGE_QUERY_VARIABLE in query.variables:
|
||||
usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
del query.variables['_stream_usage']
|
||||
del query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
|
||||
@@ -250,32 +250,81 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
- dict with the same keys
|
||||
- missing ``total_tokens`` (derived from prompt + completion)
|
||||
- ``None`` / partially-populated usage (defaults to 0)
|
||||
- provider-specific token details, including cache token counters
|
||||
"""
|
||||
if usage is None:
|
||||
return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
|
||||
def _plain_value(value: typing.Any) -> typing.Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
return {k: _plain_value(v) for k, v in value.items() if v is not None}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_plain_value(v) for v in value]
|
||||
|
||||
def _get(key: str) -> typing.Any:
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key)
|
||||
return getattr(usage, key, None)
|
||||
model_dump = getattr(value, 'model_dump', None)
|
||||
if callable(model_dump):
|
||||
try:
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return _plain_value(dumped)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
prompt_tokens = _get('prompt_tokens') or 0
|
||||
completion_tokens = _get('completion_tokens') or 0
|
||||
total_tokens = _get('total_tokens') or 0
|
||||
return value
|
||||
|
||||
def _usage_dict(value: typing.Any) -> dict[str, typing.Any]:
|
||||
if value is None:
|
||||
return {}
|
||||
plain = _plain_value(value)
|
||||
if isinstance(plain, dict):
|
||||
return plain
|
||||
|
||||
def _is_mock_attr(attr: typing.Any) -> bool:
|
||||
return type(attr).__module__.startswith('unittest.mock')
|
||||
|
||||
data: dict[str, typing.Any] = {}
|
||||
for key in (
|
||||
'prompt_tokens',
|
||||
'completion_tokens',
|
||||
'total_tokens',
|
||||
'prompt_tokens_details',
|
||||
'completion_tokens_details',
|
||||
'cache_creation_input_tokens',
|
||||
'cache_read_input_tokens',
|
||||
'input_token_details',
|
||||
'output_token_details',
|
||||
):
|
||||
attr_value = getattr(value, key, None)
|
||||
if attr_value is not None and not _is_mock_attr(attr_value):
|
||||
data[key] = _plain_value(attr_value)
|
||||
return data
|
||||
|
||||
def _to_int(value: typing.Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
normalized = _usage_dict(usage)
|
||||
|
||||
prompt_tokens = _to_int(normalized.get('prompt_tokens'))
|
||||
completion_tokens = _to_int(normalized.get('completion_tokens'))
|
||||
total_tokens = _to_int(normalized.get('total_tokens'))
|
||||
|
||||
# Some providers omit total_tokens in streaming usage; derive it.
|
||||
if not total_tokens:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return {
|
||||
'prompt_tokens': int(prompt_tokens),
|
||||
'completion_tokens': int(completion_tokens),
|
||||
'total_tokens': int(total_tokens),
|
||||
}
|
||||
normalized['prompt_tokens'] = prompt_tokens
|
||||
normalized['completion_tokens'] = completion_tokens
|
||||
normalized['total_tokens'] = total_tokens
|
||||
return normalized
|
||||
|
||||
def _extract_usage(self, response) -> dict:
|
||||
def _extract_usage(self, response) -> dict | None:
|
||||
"""Extract usage info from a non-streaming LiteLLM response."""
|
||||
return self._normalize_usage(getattr(response, 'usage', None))
|
||||
usage = getattr(response, 'usage', None)
|
||||
if usage is None:
|
||||
return None
|
||||
return self._normalize_usage(usage)
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: typing.Any) -> dict:
|
||||
@@ -474,7 +523,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if query is not None:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables['_stream_usage'] = usage_info
|
||||
query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info
|
||||
|
||||
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# SDK imports for validation
|
||||
@@ -174,6 +175,87 @@ class TestContextValidation:
|
||||
# Verify input
|
||||
assert validated.input.text == "Hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_populates_model_context_window(self):
|
||||
"""Runtime metadata should expose the selected LLM model context window."""
|
||||
mock_app = self._make_mock_app()
|
||||
mock_app.model_mgr = MagicMock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
model_entity=SimpleNamespace(context_length=128000),
|
||||
)
|
||||
)
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
event = self._make_event_envelope()
|
||||
binding = self._make_binding()
|
||||
resources = self._make_resources()
|
||||
resources['models'] = [
|
||||
{
|
||||
'model_id': 'rerank-model',
|
||||
'model_type': 'rerank',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['rerank'],
|
||||
},
|
||||
{
|
||||
'model_id': 'llm-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
]
|
||||
descriptor = self._make_descriptor()
|
||||
|
||||
with patch('langbot.pkg.agent.runner.context_builder.get_persistent_state_store') as mock_get_store:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.build_snapshot_from_event = AsyncMock(return_value={
|
||||
'conversation': {},
|
||||
'actor': {},
|
||||
'subject': {},
|
||||
'runner': {},
|
||||
})
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
context_dict = await builder.build_context_from_event(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
resources=resources,
|
||||
)
|
||||
|
||||
assert context_dict['runtime']['metadata']['model_context_window_tokens'] == 128000
|
||||
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('llm-model')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_context_window_uses_primary_llm_only(self):
|
||||
"""Fallback model windows should not replace missing primary model metadata."""
|
||||
mock_app = self._make_mock_app()
|
||||
mock_app.model_mgr = MagicMock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
model_entity=SimpleNamespace(context_length=None),
|
||||
)
|
||||
)
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
resources = self._make_resources()
|
||||
resources['models'] = [
|
||||
{
|
||||
'model_id': 'primary-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
{
|
||||
'model_id': 'fallback-model',
|
||||
'model_type': 'llm',
|
||||
'provider': 'test-provider',
|
||||
'operations': ['invoke', 'stream'],
|
||||
},
|
||||
]
|
||||
|
||||
assert await builder._build_model_context_window_tokens(resources) is None
|
||||
mock_app.model_mgr.get_model_by_uuid.assert_awaited_once_with('primary-model')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_preserves_subject_data_for_non_message_events(self):
|
||||
"""Non-message EBA events keep subject.data instead of relying on message text."""
|
||||
|
||||
@@ -388,6 +388,7 @@ class TestAgentRunProxyActions:
|
||||
def query(remove_think=True):
|
||||
return SimpleNamespace(
|
||||
pipeline_config={'output': {'misc': {'remove-think': remove_think}}},
|
||||
variables={},
|
||||
prompt=SimpleNamespace(
|
||||
messages=[provider_message.Message(role='system', content='effective prompt')]
|
||||
),
|
||||
@@ -488,6 +489,60 @@ class TestAgentRunProxyActions:
|
||||
assert kwargs['remove_think'] is True
|
||||
assert [tool.name for tool in kwargs['funcs']] == ['search']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_returns_provider_usage(self, app):
|
||||
"""INVOKE_LLM includes optional provider usage in the action response."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 11,
|
||||
'completion_tokens': 7,
|
||||
'total_tokens': 18,
|
||||
'prompt_tokens_details': {'cached_tokens': 3},
|
||||
}
|
||||
|
||||
class UsageProvider:
|
||||
async def invoke_llm(self, **kwargs):
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
return provider_message.Message(role='assistant', content='ok')
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[905] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=905,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=UsageProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['message']['content'] == 'ok'
|
||||
assert response.data['usage'] == usage
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_restores_query_and_options(self, app):
|
||||
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
|
||||
@@ -598,6 +653,63 @@ class TestAgentRunProxyActions:
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_returns_provider_usage_event(self, app):
|
||||
"""INVOKE_LLM_STREAM emits a final usage-only action response when available."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 9,
|
||||
'completion_tokens': 4,
|
||||
'total_tokens': 13,
|
||||
'prompt_tokens_details': {'cached_tokens': 2},
|
||||
}
|
||||
|
||||
class StreamProvider:
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[906] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=906,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=StreamProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert responses[0].data['chunk']['content'] == 'ok'
|
||||
assert responses[1].data == {'usage': usage}
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_passes_current_query(self, app):
|
||||
"""CALL_TOOL passes the current Query back into tool execution."""
|
||||
|
||||
@@ -115,6 +115,15 @@ class TestExtractUsage:
|
||||
assert result['prompt_tokens'] == 0
|
||||
assert result['completion_tokens'] == 0
|
||||
|
||||
def test_extract_usage_without_provider_usage(self):
|
||||
"""Missing provider usage is not treated as authoritative zero usage."""
|
||||
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||
|
||||
response = Mock()
|
||||
response.usage = None
|
||||
|
||||
assert requester._extract_usage(response) is None
|
||||
|
||||
|
||||
class TestNormalizeUsage:
|
||||
"""Test _normalize_usage helper covering real-world usage shapes"""
|
||||
@@ -131,6 +140,22 @@ class TestNormalizeUsage:
|
||||
)
|
||||
assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20}
|
||||
|
||||
def test_preserves_token_details(self):
|
||||
"""Provider token details such as cache counters are preserved."""
|
||||
result = litellmchat.LiteLLMRequester._normalize_usage(
|
||||
{
|
||||
'prompt_tokens': 12,
|
||||
'completion_tokens': 8,
|
||||
'total_tokens': 20,
|
||||
'prompt_tokens_details': {'cached_tokens': 7},
|
||||
'completion_tokens_details': {'reasoning_tokens': 3},
|
||||
}
|
||||
)
|
||||
|
||||
assert result['prompt_tokens'] == 12
|
||||
assert result['prompt_tokens_details'] == {'cached_tokens': 7}
|
||||
assert result['completion_tokens_details'] == {'reasoning_tokens': 3}
|
||||
|
||||
def test_missing_total_is_derived(self):
|
||||
"""When total_tokens is absent/zero it is derived from prompt + completion"""
|
||||
usage = Mock()
|
||||
|
||||
@@ -299,6 +299,59 @@ async def test_runtime_provider_invoke_llm_delegates(runtime_provider, runtime_l
|
||||
assert result.role == 'assistant'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_llm_stashes_usage(runtime_provider, runtime_llm_model):
|
||||
"""RuntimeProvider preserves requester usage for upstream action handlers."""
|
||||
provider = runtime_provider
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
query = pipeline_query.Query.model_construct(
|
||||
query_id='test-query-usage',
|
||||
launcher_type='person',
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=None,
|
||||
message_event=None,
|
||||
adapter=None,
|
||||
pipeline_uuid='pipeline-uuid',
|
||||
bot_uuid='bot-uuid',
|
||||
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
|
||||
session=None,
|
||||
prompt=None,
|
||||
messages=[],
|
||||
user_message=None,
|
||||
use_funcs=[],
|
||||
use_llm_model_uuid=None,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=None,
|
||||
current_stage_name=None,
|
||||
)
|
||||
usage = {
|
||||
'prompt_tokens': 11,
|
||||
'completion_tokens': 7,
|
||||
'total_tokens': 18,
|
||||
'prompt_tokens_details': {'cached_tokens': 3},
|
||||
}
|
||||
provider.requester.invoke_llm = AsyncMock(
|
||||
return_value=(
|
||||
provider_message.Message(role='assistant', content='ok'),
|
||||
usage,
|
||||
)
|
||||
)
|
||||
|
||||
result = await provider.invoke_llm(
|
||||
query,
|
||||
runtime_llm_model,
|
||||
[provider_message.Message(role='user', content='Hello')],
|
||||
)
|
||||
|
||||
assert result.content == 'ok'
|
||||
assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider, runtime_llm_model):
|
||||
"""Test RuntimeProvider.invoke_llm_stream yields chunks from requester."""
|
||||
@@ -340,6 +393,62 @@ async def test_runtime_provider_invoke_llm_stream_yields_chunks(runtime_provider
|
||||
assert chunks[0].role == 'assistant'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_llm_stream_stashes_usage(runtime_provider, runtime_llm_model):
|
||||
"""RuntimeProvider transfers captured stream usage to the public query usage key."""
|
||||
provider = runtime_provider
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
query = pipeline_query.Query.model_construct(
|
||||
query_id='test-stream-usage',
|
||||
launcher_type='person',
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=None,
|
||||
message_event=None,
|
||||
adapter=None,
|
||||
pipeline_uuid='pipeline-uuid',
|
||||
bot_uuid='bot-uuid',
|
||||
pipeline_config={'ai': {}, 'output': {}, 'trigger': {}},
|
||||
session=None,
|
||||
prompt=None,
|
||||
messages=[],
|
||||
user_message=None,
|
||||
use_funcs=[],
|
||||
use_llm_model_uuid=None,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=None,
|
||||
current_stage_name=None,
|
||||
)
|
||||
usage = {
|
||||
'prompt_tokens': 13,
|
||||
'completion_tokens': 2,
|
||||
'total_tokens': 15,
|
||||
}
|
||||
|
||||
async def fake_stream(**kwargs):
|
||||
kwargs['query'].variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
|
||||
provider.requester.invoke_llm_stream = fake_stream
|
||||
|
||||
chunks = [
|
||||
chunk
|
||||
async for chunk in provider.invoke_llm_stream(
|
||||
query,
|
||||
runtime_llm_model,
|
||||
[provider_message.Message(role='user', content='Hello')],
|
||||
)
|
||||
]
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert query.variables[requester.LLM_USAGE_QUERY_VARIABLE] == usage
|
||||
assert requester.STREAM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_provider_invoke_embedding_returns_vectors(runtime_provider, runtime_embedding_model):
|
||||
"""Test RuntimeProvider.invoke_embedding returns embedding vectors."""
|
||||
|
||||
Reference in New Issue
Block a user