refactor(plugin): split agent-runner action handlers out of handler.py

Extract the AgentRunner Protocol v1 host-side surface from the giant
RuntimeConnectionHandler.__init__ into sibling modules using a registration-
function pattern (behavior-preserving; @h.action == @self.action):

- agent_run_support.py: shared constants + authorization/scope/projection helpers
- agent_pull_actions.py: register(h) for history/event pull APIs
- agent_runner_actions.py: register(h) for run/runtime/stats/claim lifecycle
- agent_state_actions.py: register(h) for steering/state APIs

__init__ now calls the three register(self) functions. handler.py keeps the
pre-existing plugin/llm/vector/knowledge handlers, get_prompt/call_tool/
get_tool_detail (coupled to retained helpers), shared helpers, and outbound
methods; it re-imports _validate_agent_run_session so external imports keep
working. handler.py: 4066 -> 1871 lines.

test_state_api_auth.py: repoint get_session_registry patch targets to
agent_run_support (the lookup moved modules). 385 agent unit tests pass; ruff clean.
This commit is contained in:
huanghuoguoguo
2026-06-22 13:08:34 +08:00
parent 4b34d4cffd
commit c7d4885bfc
6 changed files with 2309 additions and 2212 deletions
@@ -0,0 +1,293 @@
"""Agent-runner pull actions (history / event)."""
from __future__ import annotations
from typing import Any
from langbot_plugin.runtime.io import handler
from langbot_plugin.entities.io.actions.enums import (
PluginToRuntimeAction,
)
from .agent_run_support import (
_get_run_authorization,
_validate_agent_run_session,
_resolve_run_conversation,
_run_scope_filters,
_event_matches_run_scope,
_project_event_record_for_api,
)
def register(h):
@h.action(PluginToRuntimeAction.HISTORY_PAGE)
async def history_page(data: dict[str, Any]) -> handler.ActionResponse:
"""Page through transcript history for a conversation.
Requires run_id authorization. Only allows access to current run's conversation.
"""
run_id = data.get('run_id')
conversation_id = data.get('conversation_id')
before_cursor = data.get('before_cursor')
after_cursor = data.get('after_cursor')
limit = data.get('limit', 50)
direction = data.get('direction', 'backward')
include_attachments = data.get('include_attachments', False)
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'History page',
api_capability='history_page',
)
if error:
return error
conversation_id, scope_error = _resolve_run_conversation(
session,
conversation_id,
'History page',
)
if scope_error:
return scope_error
if not conversation_id:
return handler.ActionResponse.success(
data={
'items': [],
'next_cursor': None,
'prev_cursor': None,
'has_more': False,
}
)
# Parse cursors
before_seq = int(before_cursor) if before_cursor else None
after_seq = int(after_cursor) if after_cursor else None
# Query transcript
from ..agent.runner.transcript_store import TranscriptStore
store = TranscriptStore(h.ap.persistence_mgr.get_db_engine())
try:
items, next_seq, prev_seq, has_more = await store.page_transcript(
conversation_id=conversation_id,
before_seq=before_seq,
after_seq=after_seq,
limit=limit,
direction=direction,
include_attachments=include_attachments,
**_run_scope_filters(session),
)
return handler.ActionResponse.success(
data={
'items': items,
'next_cursor': str(next_seq) if next_seq else None,
'prev_cursor': str(prev_seq) if prev_seq else None,
'has_more': has_more,
}
)
except Exception as e:
h.ap.logger.error(f'HISTORY_PAGE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'History page error: {e}')
@h.action(PluginToRuntimeAction.HISTORY_SEARCH)
async def history_search(data: dict[str, Any]) -> handler.ActionResponse:
"""Search transcript history.
Requires run_id authorization. Only searches current run's conversation.
Basic implementation using LIKE filtering.
"""
run_id = data.get('run_id')
query_text = data.get('query', '')
filters = data.get('filters') or {}
top_k = data.get('top_k', 10)
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'History search',
api_capability='history_search',
)
if error:
return error
requested_conversation_id = filters.get('conversation_id')
conversation_id, scope_error = _resolve_run_conversation(
session,
requested_conversation_id,
'History search',
)
if scope_error:
return scope_error
if not conversation_id:
return handler.ActionResponse.success(
data={
'items': [],
'total_count': 0,
'query': query_text,
}
)
# Search transcript
from ..agent.runner.transcript_store import TranscriptStore
store = TranscriptStore(h.ap.persistence_mgr.get_db_engine())
try:
safe_filters = {k: v for k, v in filters.items() if k != 'conversation_id'}
items = await store.search_transcript(
conversation_id=conversation_id,
query_text=query_text,
filters=safe_filters,
top_k=top_k,
**_run_scope_filters(session),
)
return handler.ActionResponse.success(
data={
'items': items,
'total_count': len(items),
'query': query_text,
}
)
except Exception as e:
h.ap.logger.error(f'HISTORY_SEARCH error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'History search error: {e}')
@h.action(PluginToRuntimeAction.EVENT_GET)
async def event_get(data: dict[str, Any]) -> handler.ActionResponse:
"""Get a single event record by ID.
Requires run_id authorization. Only allows access to events in current run's conversation.
"""
run_id = data.get('run_id')
event_id = data.get('event_id')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not event_id:
return handler.ActionResponse.error(message='event_id is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'Event get',
api_capability='event_get',
)
if error:
return error
# Get event
from ..agent.runner.event_log_store import EventLogStore
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
try:
event = await store.get_event(event_id)
if not event:
return handler.ActionResponse.error(message=f'Event {event_id} not found')
# Validate event is in the same conversation as the run, or was created by the same run.
session_conversation_id = _get_run_authorization(session).get('conversation_id')
event_run_id = event.get('run_id')
if event_run_id and event_run_id == run_id:
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
if not session_conversation_id or not _event_matches_run_scope(session, event):
return handler.ActionResponse.error(message=f'Event {event_id} is not accessible by this run')
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
except Exception as e:
h.ap.logger.error(f'EVENT_GET error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Event get error: {e}')
@h.action(PluginToRuntimeAction.EVENT_PAGE)
async def event_page(data: dict[str, Any]) -> handler.ActionResponse:
"""Page through event records.
Requires run_id authorization. Only allows access to current run's conversation.
"""
run_id = data.get('run_id')
conversation_id = data.get('conversation_id')
event_types = data.get('event_types')
before_cursor = data.get('before_cursor')
limit = data.get('limit', 50)
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'Event page',
api_capability='event_page',
)
if error:
return error
conversation_id, scope_error = _resolve_run_conversation(
session,
conversation_id,
'Event page',
)
if scope_error:
return scope_error
if not conversation_id:
return handler.ActionResponse.success(
data={
'items': [],
'next_cursor': None,
'prev_cursor': None,
'has_more': False,
}
)
# Parse cursor
before_seq = int(before_cursor) if before_cursor else None
# Query events
from ..agent.runner.event_log_store import EventLogStore
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
try:
items, next_seq, has_more = await store.page_events(
conversation_id=conversation_id,
event_types=event_types,
before_seq=before_seq,
limit=limit,
**_run_scope_filters(session),
)
return handler.ActionResponse.success(
data={
'items': [_project_event_record_for_api(item) for item in items],
'next_cursor': str(next_seq) if next_seq else None,
'prev_cursor': None,
'has_more': has_more,
}
)
except Exception as e:
h.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Event page error: {e}')
+488
View File
@@ -0,0 +1,488 @@
"""Agent-runner protocol support: shared constants and authorization/scope/projection helpers extracted from handler.py."""
from __future__ import annotations
from typing import Any, Union
import json
import time
import sqlalchemy
from langbot_plugin.runtime.io import handler
from langbot_plugin.entities.io.actions.enums import (
PluginToRuntimeAction,
)
from ..core import app
from ..agent.runner.session_registry import get_session_registry
from ..agent.runner.result_normalizer import MAX_RESULT_SIZE_BYTES, STRICT_RESULT_PAYLOADS
class _RuntimeActionName:
def __init__(self, value: str):
self.value = value
AGENT_RUN_ADMIN_PERMISSION = 'agent_run:admin'
RUNTIME_ADMIN_PERMISSION = 'runtime:admin'
AGENT_RUNNER_ADMIN_PERMISSION = 'agent_runner:admin'
LEDGER_ONLY_SIDE_EFFECTING_RESULT_TYPES = {
'message.delta',
'message.completed',
'state.updated',
'run.completed',
'run.failed',
}
def _plugin_runtime_action(name: str, value: str) -> Any:
return getattr(PluginToRuntimeAction, name, _RuntimeActionName(value))
def _normalize_permission_set(value: Any) -> set[str]:
if isinstance(value, str):
return {permission.strip() for permission in value.split(',') if permission.strip()}
if isinstance(value, list):
return {str(item).strip() for item in value if str(item).strip()}
if isinstance(value, dict):
return {str(item).strip() for item, enabled in value.items() if enabled and str(item).strip()}
return set()
def _iter_agent_runner_admin_plugin_configs(ap: app.Application) -> list[dict[str, Any]]:
instance_config = getattr(ap, 'instance_config', None)
config_data = getattr(instance_config, 'data', {}) if instance_config is not None else {}
if not isinstance(config_data, dict):
return []
agent_runner_config = config_data.get('agent_runner', {})
if not isinstance(agent_runner_config, dict):
return []
raw_admin_plugins = agent_runner_config.get('admin_plugins', [])
if isinstance(raw_admin_plugins, dict):
items: list[dict[str, Any]] = []
for identity, entry in raw_admin_plugins.items():
if isinstance(entry, dict):
merged = dict(entry)
merged.setdefault('identity', identity)
items.append(merged)
else:
items.append({'identity': identity, 'permissions': entry})
return items
if isinstance(raw_admin_plugins, list):
return [item for item in raw_admin_plugins if isinstance(item, dict)]
return []
def _agent_runner_admin_permissions(ap: app.Application, plugin_identity: str | None) -> set[str]:
if not isinstance(plugin_identity, str) or not plugin_identity.strip():
return set()
normalized_identity = plugin_identity.strip()
permissions: set[str] = set()
for entry in _iter_agent_runner_admin_plugin_configs(ap):
if entry.get('enabled', True) is False:
continue
identity = entry.get('identity') or entry.get('plugin_identity') or entry.get('plugin') or entry.get('id')
if identity != normalized_identity:
continue
permissions.update(_normalize_permission_set(entry.get('permissions')))
permissions.update(_normalize_permission_set(entry.get('scopes')))
return permissions
def _has_agent_runner_admin_permission(
ap: app.Application,
plugin_identity: str | None,
permission: str,
) -> bool:
permissions = _agent_runner_admin_permissions(ap, plugin_identity)
if not permissions:
return False
domain = permission.split(':', 1)[0]
return bool(
permission in permissions
or f'{domain}:*' in permissions
or AGENT_RUNNER_ADMIN_PERMISSION in permissions
or '*' in permissions
)
def _deadline_seconds_from_payload(data: dict[str, Any], default: int = 60) -> int:
deadline_at = data.get('heartbeat_deadline_at')
if deadline_at is not None:
try:
return max(int(float(deadline_at) - time.time()), 1)
except (TypeError, ValueError):
pass
try:
return max(int(data.get('heartbeat_ttl_seconds') or default), 1)
except (TypeError, ValueError):
return default
def _get_run_authorization(session: dict[str, Any]) -> dict[str, Any]:
"""Return the run-scoped authorization snapshot."""
return session['authorization']
def _run_matches_run_scope(session: dict[str, Any], run: dict[str, Any]) -> bool:
authorization = _get_run_authorization(session)
session_run_id = session.get('run_id')
if run.get('run_id') == session_run_id:
return True
session_runner_id = session.get('runner_id') or authorization.get('runner_id')
if not session_runner_id or run.get('runner_id') != session_runner_id:
return False
if not authorization.get('conversation_id'):
return False
if run.get('conversation_id') != authorization.get('conversation_id'):
return False
if authorization.get('bot_id') is not None and authorization.get('bot_id') != run.get('bot_id'):
return False
if authorization.get('workspace_id') is not None and authorization.get('workspace_id') != run.get('workspace_id'):
return False
if authorization.get('thread_id') != run.get('thread_id'):
return False
return True
def _authorize_target_run(
session: dict[str, Any],
run: dict[str, Any],
) -> handler.ActionResponse | None:
"""Authorize non-admin target-run access against scope and runner owner."""
if _run_matches_run_scope(session, run):
return None
return handler.ActionResponse.error(message=f'Run {run.get("run_id")} is not accessible by this run')
def _validate_ledger_only_result_payload(
*,
ap: app.Application,
runner_id: str | None,
event_type: str,
data: dict[str, Any],
) -> str | None:
"""Validate result payloads that can be safely stored without side effects."""
try:
result_json = json.dumps({'type': event_type, 'data': data})
except (TypeError, ValueError) as exc:
return f'event data must be JSON serializable: {exc}'
if len(result_json) > MAX_RESULT_SIZE_BYTES:
return f'event payload exceeds {MAX_RESULT_SIZE_BYTES} bytes'
payload_model = STRICT_RESULT_PAYLOADS.get(event_type)
if payload_model is None:
return f'unknown result type: {event_type}'
try:
payload_model.model_validate(data)
except Exception as exc:
return f'invalid {event_type} payload: {exc}'
if event_type in LEDGER_ONLY_SIDE_EFFECTING_RESULT_TYPES:
if runner_id:
ap.logger.warning(
f'Runner {runner_id} attempted ledger-only append for side-effecting result type {event_type}'
)
return f'{event_type} must be emitted through the canonical runner result path'
return None
async def _require_runtime_write_ownership(
*,
store: Any,
session: dict[str, Any],
run: dict[str, Any],
data: dict[str, Any],
api_name: str,
) -> handler.ActionResponse | None:
"""Require current-run ownership or an active runtime claim for run writes."""
if run.get('run_id') == session.get('run_id') and run.get('status') != 'claimed':
return None
runtime_id = data.get('runtime_id')
claim_token = data.get('claim_token')
if not runtime_id or not claim_token:
return handler.ActionResponse.error(
message=f'{api_name} requires active claim ownership for target run {run.get("run_id")}'
)
if not await store.validate_active_claim(
run_id=str(run.get('run_id')),
runtime_id=str(runtime_id),
claim_token=str(claim_token),
):
return handler.ActionResponse.error(
message=f'{api_name} claim ownership is not active for target run {run.get("run_id")}'
)
return None
def _resolve_state_scope(
session: dict[str, Any],
scope: str,
) -> tuple[dict[str, Any] | None, str | None, handler.ActionResponse | None]:
"""Resolve state policy/context for an authorized run scope."""
authorization = _get_run_authorization(session)
state_policy = authorization['state_policy']
if not state_policy.get('enable_state', True):
return None, None, handler.ActionResponse.error(message='State access is disabled by binding policy')
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
if scope not in state_scopes:
return None, None, handler.ActionResponse.error(message=f'Scope "{scope}" is not enabled by binding policy')
state_context = authorization['state_context']
scope_key = state_context.get('scope_keys', {}).get(scope)
if not scope_key:
return None, None, handler.ActionResponse.error(message=f'Scope key not available for scope "{scope}"')
return state_context, scope_key, None
async def _validate_agent_run_session(
run_id: str,
caller_plugin_identity: str | None,
ap: app.Application,
api_name: str,
api_capability: str | None = None,
allow_persistent_authorization: bool = False,
admin_permission: str | None = None,
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
"""Validate an AgentRunner pull API run session and run-scoped API access."""
if (
not run_id
and admin_permission
and _has_agent_runner_admin_permission(
ap,
caller_plugin_identity,
admin_permission,
)
):
return {
'run_id': run_id,
'runner_id': None,
'query_id': None,
'plugin_identity': caller_plugin_identity,
'authorization': {},
'status': {},
'steering_queue': [],
}, None
session_registry = get_session_registry()
session = await session_registry.get(run_id)
if not session:
if allow_persistent_authorization:
session = await _load_persistent_agent_run_session(run_id, ap, api_name)
if not session:
return None, handler.ActionResponse.error(message=f'Run session {run_id} not found or expired')
session_plugin_identity = session.get('plugin_identity')
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity')
return None, handler.ActionResponse.error(message=f'Run session {run_id} has no plugin_identity')
if not caller_plugin_identity:
return None, handler.ActionResponse.error(message=f'caller_plugin_identity is required for run_id {run_id}')
if caller_plugin_identity != session_plugin_identity:
ap.logger.warning(
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
f'does not match session plugin_identity {session_plugin_identity}'
)
return None, handler.ActionResponse.error(message=f'Plugin identity mismatch for run_id {run_id}')
if api_capability:
available_apis = _get_run_authorization(session).get('available_apis', {})
has_admin_permission = bool(admin_permission) and _has_agent_runner_admin_permission(
ap,
caller_plugin_identity,
admin_permission,
)
if not available_apis.get(api_capability, False) and not has_admin_permission:
return None, handler.ActionResponse.error(message=f'{api_name} access not authorized')
return session, None
async def _load_persistent_agent_run_session(
run_id: str,
ap: app.Application,
api_name: str,
) -> dict[str, Any] | None:
"""Load an expired run session from the AgentRun authorization snapshot."""
try:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from ..entity.persistence.agent_run import AgentRun
engine = ap.persistence_mgr.get_db_engine()
session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async with session_factory() as db_session:
result = await db_session.execute(sqlalchemy.select(AgentRun).where(AgentRun.run_id == run_id))
run = result.scalars().first()
except Exception as e:
ap.logger.error(f'{api_name}: failed to load persistent authorization for run_id {run_id}: {e}', exc_info=True)
return None
if run is None:
return None
try:
authorization = json.loads(run.authorization_json) if run.authorization_json else {}
except (TypeError, ValueError) as e:
ap.logger.warning(f'{api_name}: run_id {run_id} has invalid authorization_json: {e}')
return None
if not isinstance(authorization, dict):
ap.logger.warning(f'{api_name}: run_id {run_id} authorization_json is not an object')
return None
return {
'run_id': run.run_id,
'runner_id': authorization.get('runner_id') or run.runner_id,
'query_id': None,
'plugin_identity': authorization.get('plugin_identity'),
'authorization': authorization,
'status': {},
'steering_queue': [],
}
def _resolve_run_conversation(
session: dict[str, Any],
requested_conversation_id: str | None,
api_name: str,
) -> tuple[str | None, handler.ActionResponse | None]:
"""Resolve and enforce current-run conversation scope."""
session_conversation_id = _get_run_authorization(session).get('conversation_id')
if requested_conversation_id:
if not session_conversation_id:
return None, handler.ActionResponse.error(message=f'{api_name} is not available without a run conversation')
if requested_conversation_id != session_conversation_id:
return None, handler.ActionResponse.error(
message=f'Conversation {requested_conversation_id} is not accessible by this run'
)
return requested_conversation_id, None
return session_conversation_id, None
def _run_scope_filters(session: dict[str, Any]) -> dict[str, Any]:
authorization = _get_run_authorization(session)
return {
'bot_id': authorization.get('bot_id'),
'workspace_id': authorization.get('workspace_id'),
'thread_id': authorization.get('thread_id'),
'strict_thread': True,
}
def _run_ledger_scope_filters(session: dict[str, Any]) -> dict[str, Any]:
authorization = _get_run_authorization(session)
filters = _run_scope_filters(session)
filters['runner_id'] = session.get('runner_id') or authorization.get('runner_id')
return filters
def _event_matches_run_scope(session: dict[str, Any], event: dict[str, Any]) -> bool:
authorization = _get_run_authorization(session)
if authorization.get('conversation_id') != event.get('conversation_id'):
return False
if authorization.get('bot_id') is not None and authorization.get('bot_id') != event.get('bot_id'):
return False
if authorization.get('workspace_id') is not None and authorization.get('workspace_id') != event.get('workspace_id'):
return False
if authorization.get('thread_id') != event.get('thread_id'):
return False
return True
def _project_event_record_for_api(event: dict[str, Any]) -> dict[str, Any]:
"""Project EventLogStore rows onto the SDK AgentEventRecord DTO."""
seq = event.get('seq') or event.get('id')
return {
'event_id': event.get('event_id'),
'event_type': event.get('event_type'),
'event_time': event.get('event_time'),
'source': event.get('source'),
'bot_id': event.get('bot_id'),
'workspace_id': event.get('workspace_id'),
'conversation_id': event.get('conversation_id'),
'thread_id': event.get('thread_id'),
'actor_type': event.get('actor_type'),
'actor_id': event.get('actor_id'),
'actor_name': event.get('actor_name'),
'subject_type': event.get('subject_type'),
'subject_id': event.get('subject_id'),
'input_summary': event.get('input_summary'),
'input_ref': event.get('input_ref'),
'raw_ref': event.get('raw_ref'),
'seq': seq,
'cursor': event.get('cursor') or (str(seq) if seq is not None else None),
'created_at': event.get('created_at'),
'metadata': event.get('metadata') or {},
}
def _project_runner_descriptor_for_api(descriptor: Any) -> dict[str, Any]:
"""Project an AgentRunnerDescriptor-like object onto a JSON dict."""
if isinstance(descriptor, dict):
return dict(descriptor)
if hasattr(descriptor, 'model_dump'):
return descriptor.model_dump(mode='json')
return {
'id': getattr(descriptor, 'id', None),
'source': getattr(descriptor, 'source', None),
'label': getattr(descriptor, 'label', {}),
'description': getattr(descriptor, 'description', None),
'plugin_author': getattr(descriptor, 'plugin_author', None),
'plugin_name': getattr(descriptor, 'plugin_name', None),
'runner_name': getattr(descriptor, 'runner_name', None),
'plugin_version': getattr(descriptor, 'plugin_version', None),
'config_schema': getattr(descriptor, 'config_schema', []),
'capabilities': getattr(descriptor, 'capabilities', {}),
'permissions': getattr(descriptor, 'permissions', {}),
'raw_manifest': getattr(descriptor, 'raw_manifest', {}),
}
async def _record_agent_runner_admin_action(
ap: app.Application,
store: Any,
*,
action: str,
caller_plugin_identity: str | None,
permission: str,
durable_run_id: str | None = None,
target_runtime_id: str | None = None,
detail: dict[str, Any] | None = None,
) -> None:
"""Record a small audit trail for privileged AgentRunner operations."""
audit_data: dict[str, Any] = {
'action': action,
'caller_plugin_identity': caller_plugin_identity,
'permission': permission,
}
if durable_run_id:
audit_data['target_run_id'] = durable_run_id
if target_runtime_id:
audit_data['target_runtime_id'] = target_runtime_id
if detail:
audit_data['detail'] = detail
ap.logger.info('Agent runner admin action: %s', audit_data)
if not durable_run_id or store is None or not hasattr(store, 'append_audit_event'):
return
try:
await store.append_audit_event(
run_id=str(durable_run_id),
event_type=f'admin.{action}',
data=audit_data,
metadata={'permission': permission},
)
except Exception as exc:
ap.logger.warning(f'Failed to record AgentRunner admin audit event: {exc}', exc_info=True)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,316 @@
"""Agent-runner steering / state actions."""
from __future__ import annotations
from typing import Any
from langbot_plugin.runtime.io import handler
from langbot_plugin.entities.io.actions.enums import (
PluginToRuntimeAction,
)
from ..agent.runner.session_registry import get_session_registry
from .agent_run_support import (
_resolve_state_scope,
_validate_agent_run_session,
)
def register(h):
@h.action(PluginToRuntimeAction.STEERING_PULL)
async def steering_pull(data: dict[str, Any]) -> handler.ActionResponse:
"""Pull pending steering/follow-up inputs for the current run."""
run_id = data.get('run_id')
mode = data.get('mode', 'all')
limit = data.get('limit')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if limit is not None:
try:
limit = int(limit)
except (TypeError, ValueError):
return handler.ActionResponse.error(message='limit must be an integer')
if limit <= 0:
return handler.ActionResponse.error(message='limit must be > 0')
limit = min(limit, 100)
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'Steering pull',
api_capability='steering_pull',
)
if error:
return error
session_registry = get_session_registry()
items = await session_registry.pull_steering(
run_id,
mode=str(mode or 'all'),
limit=limit,
)
if items:
try:
from ..agent.runner.event_log_store import EventLogStore
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
for item in items:
event = item.get('event') if isinstance(item, dict) else None
conversation = item.get('conversation') if isinstance(item, dict) else None
actor = item.get('actor') if isinstance(item, dict) else None
subject = item.get('subject') if isinstance(item, dict) else None
if not isinstance(event, dict):
continue
await store.append_event(
event_id=None,
event_type='steering.injected',
source='agent_runner',
bot_id=conversation.get('bot_id') if isinstance(conversation, dict) else None,
workspace_id=conversation.get('workspace_id') if isinstance(conversation, dict) else None,
conversation_id=conversation.get('conversation_id') if isinstance(conversation, dict) else None,
thread_id=conversation.get('thread_id') if isinstance(conversation, dict) else None,
actor_type=actor.get('actor_type') if isinstance(actor, dict) else None,
actor_id=actor.get('actor_id') if isinstance(actor, dict) else None,
actor_name=actor.get('actor_name') if isinstance(actor, dict) else None,
subject_type=subject.get('subject_type') if isinstance(subject, dict) else None,
subject_id=subject.get('subject_id') if isinstance(subject, dict) else None,
input_summary=f'steering injected from {event.get("event_id")}',
run_id=run_id,
runner_id=session.get('runner_id') if isinstance(session, dict) else None,
metadata={
'steering': {
'status': 'injected',
'source_event_id': event.get('event_id'),
'claimed_by_run_id': item.get('claimed_run_id') if isinstance(item, dict) else run_id,
'claimed_runner_id': item.get('runner_id') if isinstance(item, dict) else None,
'claimed_at': item.get('claimed_at') if isinstance(item, dict) else None,
'pull_mode': str(mode or 'all'),
},
},
)
except Exception as exc:
h.ap.logger.warning(
f'Failed to write steering injection audit for run {run_id}: {exc}',
exc_info=True,
)
return handler.ActionResponse.success(data={'items': items})
# ================= State APIs (run-scoped, policy-enforced) =================
@h.action(PluginToRuntimeAction.STATE_GET)
async def state_get(data: dict[str, Any]) -> handler.ActionResponse:
"""Get a state value from host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'State get',
api_capability='state',
)
if error:
return error
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
if state_error:
return state_error
# Get state from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
try:
value = await store.state_get(scope_key, key)
return handler.ActionResponse.success(data={'value': value})
except Exception as e:
h.ap.logger.error(f'STATE_GET error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State get error: {e}')
@h.action(PluginToRuntimeAction.STATE_SET)
async def state_set(data: dict[str, Any]) -> handler.ActionResponse:
"""Set a state value in host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
Value must be JSON-serializable and size-limited.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
value = data.get('value')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'State set',
api_capability='state',
)
if error:
return error
state_context, scope_key, state_error = _resolve_state_scope(session, scope)
if state_error:
return state_error
# Get additional context for DB insert
runner_id = session.get('runner_id', '')
binding_identity = state_context.get('binding_identity', 'unknown')
# Set state in persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
try:
success, error = await store.state_set(
scope_key=scope_key,
state_key=key,
value=value,
runner_id=runner_id,
binding_identity=binding_identity,
scope=scope,
context=state_context,
logger=h.ap.logger,
)
if not success:
return handler.ActionResponse.error(message=error or 'Failed to set state')
return handler.ActionResponse.success(data={'success': True})
except Exception as e:
h.ap.logger.error(f'STATE_SET error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State set error: {e}')
@h.action(PluginToRuntimeAction.STATE_DELETE)
async def state_delete(data: dict[str, Any]) -> handler.ActionResponse:
"""Delete a state value from host-owned state store.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
key = data.get('key')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
if not key:
return handler.ActionResponse.error(message='key is required')
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'State delete',
api_capability='state',
)
if error:
return error
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
if state_error:
return state_error
# Delete state from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
try:
deleted = await store.state_delete(scope_key, key)
return handler.ActionResponse.success(data={'success': deleted})
except Exception as e:
h.ap.logger.error(f'STATE_DELETE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State delete error: {e}')
@h.action(PluginToRuntimeAction.STATE_LIST)
async def state_list(data: dict[str, Any]) -> handler.ActionResponse:
"""List state keys in a scope.
Requires run_id authorization and scope enabled by state_policy.
"""
run_id = data.get('run_id')
scope = data.get('scope')
prefix = data.get('prefix')
limit = data.get('limit', 100)
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if not scope:
return handler.ActionResponse.error(message='scope is required')
# Validate limit
if not isinstance(limit, int) or limit <= 0:
limit = 100
limit = min(limit, 100) # Cap at 100
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
h.ap,
'State list',
api_capability='state',
)
if error:
return error
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
if state_error:
return state_error
# List state keys from persistent store
from ..agent.runner.persistent_state_store import get_persistent_state_store
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
try:
keys, has_more = await store.state_list(scope_key, prefix, limit)
return handler.ActionResponse.success(
data={
'keys': keys,
'has_more': has_more,
}
)
except Exception as e:
h.ap.logger.error(f'STATE_LIST error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'State list error: {e}')
File diff suppressed because it is too large Load Diff
+10 -10
View File
@@ -90,7 +90,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
# Get the STATE_GET action handler (actions dict is keyed by action value string)
@@ -111,7 +111,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -146,7 +146,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -182,7 +182,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -219,7 +219,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -255,7 +255,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -292,7 +292,7 @@ class TestStateAPIHandlerAuthorization:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -340,7 +340,7 @@ class TestStateAPIFullFlowWithRealDB:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
# Verify session has correct state_context
@@ -446,7 +446,7 @@ class TestStateHandlerReadsFromAuthorizationSnapshot:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
@@ -490,7 +490,7 @@ class TestStateHandlerReadsFromAuthorizationSnapshot:
async def fake_disconnect():
return True
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]