Propagate agent runner model usage context

This commit is contained in:
huanghuoguoguo
2026-06-14 07:41:57 +08:00
parent 1153433693
commit 09adf4c541
9 changed files with 507 additions and 27 deletions
@@ -179,6 +179,52 @@ class AgentRunContextBuilder:
def __init__(self, ap: app.Application):
self.ap = ap
@staticmethod
def _positive_int(value: typing.Any) -> int | None:
if isinstance(value, bool):
return None
if isinstance(value, int) and value > 0:
return value
if isinstance(value, str) and value.isdigit():
parsed_value = int(value)
if parsed_value > 0:
return parsed_value
return None
@staticmethod
def _is_llm_model_resource(model_resource: ModelResource) -> bool:
operations = model_resource.get('operations')
if isinstance(operations, list) and operations:
return bool({'invoke', 'stream'} & {str(operation) for operation in operations})
return model_resource.get('model_type') != 'rerank'
async def _build_model_context_window_tokens(self, resources: AgentResources) -> int | None:
model_mgr = getattr(self.ap, 'model_mgr', None)
if model_mgr is None:
return None
for model_resource in resources.get('models', []):
if not self._is_llm_model_resource(model_resource):
continue
model_uuid = model_resource.get('model_id')
if not isinstance(model_uuid, str) or not model_uuid:
continue
try:
model = await model_mgr.get_model_by_uuid(model_uuid)
except Exception as exc:
logger = getattr(self.ap, 'logger', None)
if logger is not None:
logger.debug(f'Failed to resolve model context window for {model_uuid}: {exc}')
continue
model_entity = getattr(model, 'model_entity', None)
context_length = self._positive_int(getattr(model_entity, 'context_length', None))
return context_length
return None
async def build_context_from_event(
self,
event: AgentEventEnvelope,
@@ -270,6 +316,8 @@ class AgentRunContextBuilder:
persistent_state_store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
state: AgentRunState = await persistent_state_store.build_snapshot_from_event(event, binding, descriptor)
model_context_window_tokens = await self._build_model_context_window_tokens(resources)
# Build runtime context
runtime: AgentRuntimeContext = {
'langbot_version': self.ap.ver_mgr.get_current_version(),
@@ -279,10 +327,7 @@ class AgentRunContextBuilder:
'bot_id': event.bot_id,
'workspace_id': event.workspace_id,
'streaming_supported': event.delivery.supports_streaming,
'model_context_window_tokens': None,
# TODO(model-info): populate model_context_window_tokens after
# LiteLLM/model metadata lands. Runners fall back to their
# ctx.config until Host can provide the real window.
'model_context_window_tokens': model_context_window_tokens,
},
}
+33 -3
View File
@@ -21,6 +21,7 @@ import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from ..entity.persistence import plugin as persistence_plugin
from ..entity.persistence import bstorage as persistence_bstorage
from ..provider.modelmgr import requester as model_requester
from ..core import app
from ..utils import constants
@@ -43,6 +44,18 @@ def _make_rag_error_response(error: Exception, error_type: str, **extra_context)
return handler.ActionResponse.error(message=message)
def _pop_query_llm_usage(query: Any) -> dict[str, Any] | None:
"""Read provider usage stashed on a query by RuntimeProvider."""
if query is None or not getattr(query, 'variables', None):
return None
usage = query.variables.pop(model_requester.LLM_USAGE_QUERY_VARIABLE, None)
if usage is None:
return None
if isinstance(usage, dict):
return dict(usage)
return None
def _i18n_to_dict(value: Any) -> dict[str, Any]:
"""Convert SDK i18n values to plain dictionaries."""
if value is None:
@@ -802,10 +815,20 @@ class RuntimeConnectionHandler(handler.Handler):
remove_think=remove_think,
)
usage = None
if isinstance(result, tuple):
result, usage = result
if usage is None:
usage = _pop_query_llm_usage(query)
response_data = {
'message': result.model_dump(),
}
if usage is not None:
response_data['usage'] = usage
return handler.ActionResponse.success(
data={
'message': result.model_dump(),
},
data=response_data,
)
@self.action(PluginToRuntimeAction.INVOKE_LLM_STREAM)
@@ -867,6 +890,13 @@ class RuntimeConnectionHandler(handler.Handler):
'chunk': chunk.model_dump(),
},
)
usage = _pop_query_llm_usage(query)
if usage is not None:
yield handler.ActionResponse.success(
data={
'usage': usage,
},
)
@self.action(PluginToRuntimeAction.CALL_TOOL)
async def call_tool(data: dict[str, Any]) -> handler.ActionResponse:
+18 -3
View File
@@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
"""Store the latest provider usage on the query for upstream action handlers."""
if query is None or not usage_info:
return
if query.variables is None:
query.variables = {}
query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info)
class RuntimeProvider:
"""运行时模型提供商"""
@@ -67,6 +80,7 @@ class RuntimeProvider:
if isinstance(result, tuple):
msg, usage_info = result
if usage_info:
_store_llm_usage(query, usage_info)
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
return msg
@@ -146,11 +160,12 @@ class RuntimeProvider:
if query:
if query.variables is None:
query.variables = {}
if '_stream_usage' in query.variables:
usage_info = query.variables['_stream_usage']
if STREAM_USAGE_QUERY_VARIABLE in query.variables:
usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE]
_store_llm_usage(query, usage_info)
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
del query.variables['_stream_usage']
del query.variables[STREAM_USAGE_QUERY_VARIABLE]
except Exception as e:
status = 'error'
error_message = str(e)
@@ -250,32 +250,81 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
- dict with the same keys
- missing ``total_tokens`` (derived from prompt + completion)
- ``None`` / partially-populated usage (defaults to 0)
- provider-specific token details, including cache token counters
"""
if usage is None:
return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
def _plain_value(value: typing.Any) -> typing.Any:
if value is None:
return None
if isinstance(value, dict):
return {k: _plain_value(v) for k, v in value.items() if v is not None}
if isinstance(value, (list, tuple)):
return [_plain_value(v) for v in value]
def _get(key: str) -> typing.Any:
if isinstance(usage, dict):
return usage.get(key)
return getattr(usage, key, None)
model_dump = getattr(value, 'model_dump', None)
if callable(model_dump):
try:
dumped = model_dump()
if isinstance(dumped, dict):
return _plain_value(dumped)
except Exception:
pass
prompt_tokens = _get('prompt_tokens') or 0
completion_tokens = _get('completion_tokens') or 0
total_tokens = _get('total_tokens') or 0
return value
def _usage_dict(value: typing.Any) -> dict[str, typing.Any]:
if value is None:
return {}
plain = _plain_value(value)
if isinstance(plain, dict):
return plain
def _is_mock_attr(attr: typing.Any) -> bool:
return type(attr).__module__.startswith('unittest.mock')
data: dict[str, typing.Any] = {}
for key in (
'prompt_tokens',
'completion_tokens',
'total_tokens',
'prompt_tokens_details',
'completion_tokens_details',
'cache_creation_input_tokens',
'cache_read_input_tokens',
'input_token_details',
'output_token_details',
):
attr_value = getattr(value, key, None)
if attr_value is not None and not _is_mock_attr(attr_value):
data[key] = _plain_value(attr_value)
return data
def _to_int(value: typing.Any) -> int:
try:
return int(value or 0)
except (TypeError, ValueError):
return 0
normalized = _usage_dict(usage)
prompt_tokens = _to_int(normalized.get('prompt_tokens'))
completion_tokens = _to_int(normalized.get('completion_tokens'))
total_tokens = _to_int(normalized.get('total_tokens'))
# Some providers omit total_tokens in streaming usage; derive it.
if not total_tokens:
total_tokens = prompt_tokens + completion_tokens
return {
'prompt_tokens': int(prompt_tokens),
'completion_tokens': int(completion_tokens),
'total_tokens': int(total_tokens),
}
normalized['prompt_tokens'] = prompt_tokens
normalized['completion_tokens'] = completion_tokens
normalized['total_tokens'] = total_tokens
return normalized
def _extract_usage(self, response) -> dict:
def _extract_usage(self, response) -> dict | None:
"""Extract usage info from a non-streaming LiteLLM response."""
return self._normalize_usage(getattr(response, 'usage', None))
usage = getattr(response, 'usage', None)
if usage is None:
return None
return self._normalize_usage(usage)
@staticmethod
def _as_dict(value: typing.Any) -> dict:
@@ -474,7 +523,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
if query is not None:
if query.variables is None:
query.variables = {}
query.variables['_stream_usage'] = usage_info
query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info
if not hasattr(chunk, 'choices') or not chunk.choices:
continue