mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-22 05:24:23 +00:00
refactor(plugin): split agent-runner action handlers out of handler.py
Extract the AgentRunner Protocol v1 host-side surface from the giant RuntimeConnectionHandler.__init__ into sibling modules using a registration- function pattern (behavior-preserving; @h.action == @self.action): - agent_run_support.py: shared constants + authorization/scope/projection helpers - agent_pull_actions.py: register(h) for history/event pull APIs - agent_runner_actions.py: register(h) for run/runtime/stats/claim lifecycle - agent_state_actions.py: register(h) for steering/state APIs __init__ now calls the three register(self) functions. handler.py keeps the pre-existing plugin/llm/vector/knowledge handlers, get_prompt/call_tool/ get_tool_detail (coupled to retained helpers), shared helpers, and outbound methods; it re-imports _validate_agent_run_session so external imports keep working. handler.py: 4066 -> 1871 lines. test_state_api_auth.py: repoint get_session_registry patch targets to agent_run_support (the lookup moved modules). 385 agent unit tests pass; ruff clean.
This commit is contained in:
@@ -0,0 +1,293 @@
|
||||
"""Agent-runner pull actions (history / event)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
from langbot_plugin.runtime.io import handler
|
||||
from langbot_plugin.entities.io.actions.enums import (
|
||||
PluginToRuntimeAction,
|
||||
)
|
||||
|
||||
|
||||
from .agent_run_support import (
|
||||
_get_run_authorization,
|
||||
_validate_agent_run_session,
|
||||
_resolve_run_conversation,
|
||||
_run_scope_filters,
|
||||
_event_matches_run_scope,
|
||||
_project_event_record_for_api,
|
||||
)
|
||||
|
||||
|
||||
def register(h):
|
||||
@h.action(PluginToRuntimeAction.HISTORY_PAGE)
|
||||
async def history_page(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Page through transcript history for a conversation.
|
||||
|
||||
Requires run_id authorization. Only allows access to current run's conversation.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
conversation_id = data.get('conversation_id')
|
||||
before_cursor = data.get('before_cursor')
|
||||
after_cursor = data.get('after_cursor')
|
||||
limit = data.get('limit', 50)
|
||||
direction = data.get('direction', 'backward')
|
||||
include_attachments = data.get('include_attachments', False)
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'History page',
|
||||
api_capability='history_page',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
conversation_id, scope_error = _resolve_run_conversation(
|
||||
session,
|
||||
conversation_id,
|
||||
'History page',
|
||||
)
|
||||
if scope_error:
|
||||
return scope_error
|
||||
|
||||
if not conversation_id:
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': [],
|
||||
'next_cursor': None,
|
||||
'prev_cursor': None,
|
||||
'has_more': False,
|
||||
}
|
||||
)
|
||||
|
||||
# Parse cursors
|
||||
before_seq = int(before_cursor) if before_cursor else None
|
||||
after_seq = int(after_cursor) if after_cursor else None
|
||||
|
||||
# Query transcript
|
||||
from ..agent.runner.transcript_store import TranscriptStore
|
||||
|
||||
store = TranscriptStore(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id=conversation_id,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
limit=limit,
|
||||
direction=direction,
|
||||
include_attachments=include_attachments,
|
||||
**_run_scope_filters(session),
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': items,
|
||||
'next_cursor': str(next_seq) if next_seq else None,
|
||||
'prev_cursor': str(prev_seq) if prev_seq else None,
|
||||
'has_more': has_more,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'HISTORY_PAGE error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'History page error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.HISTORY_SEARCH)
|
||||
async def history_search(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Search transcript history.
|
||||
|
||||
Requires run_id authorization. Only searches current run's conversation.
|
||||
Basic implementation using LIKE filtering.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
query_text = data.get('query', '')
|
||||
filters = data.get('filters') or {}
|
||||
top_k = data.get('top_k', 10)
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'History search',
|
||||
api_capability='history_search',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
requested_conversation_id = filters.get('conversation_id')
|
||||
conversation_id, scope_error = _resolve_run_conversation(
|
||||
session,
|
||||
requested_conversation_id,
|
||||
'History search',
|
||||
)
|
||||
if scope_error:
|
||||
return scope_error
|
||||
|
||||
if not conversation_id:
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': [],
|
||||
'total_count': 0,
|
||||
'query': query_text,
|
||||
}
|
||||
)
|
||||
|
||||
# Search transcript
|
||||
from ..agent.runner.transcript_store import TranscriptStore
|
||||
|
||||
store = TranscriptStore(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
safe_filters = {k: v for k, v in filters.items() if k != 'conversation_id'}
|
||||
items = await store.search_transcript(
|
||||
conversation_id=conversation_id,
|
||||
query_text=query_text,
|
||||
filters=safe_filters,
|
||||
top_k=top_k,
|
||||
**_run_scope_filters(session),
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': items,
|
||||
'total_count': len(items),
|
||||
'query': query_text,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'HISTORY_SEARCH error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'History search error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.EVENT_GET)
|
||||
async def event_get(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get a single event record by ID.
|
||||
|
||||
Requires run_id authorization. Only allows access to events in current run's conversation.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
event_id = data.get('event_id')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not event_id:
|
||||
return handler.ActionResponse.error(message='event_id is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'Event get',
|
||||
api_capability='event_get',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
# Get event
|
||||
from ..agent.runner.event_log_store import EventLogStore
|
||||
|
||||
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
event = await store.get_event(event_id)
|
||||
if not event:
|
||||
return handler.ActionResponse.error(message=f'Event {event_id} not found')
|
||||
|
||||
# Validate event is in the same conversation as the run, or was created by the same run.
|
||||
session_conversation_id = _get_run_authorization(session).get('conversation_id')
|
||||
event_run_id = event.get('run_id')
|
||||
if event_run_id and event_run_id == run_id:
|
||||
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
|
||||
if not session_conversation_id or not _event_matches_run_scope(session, event):
|
||||
return handler.ActionResponse.error(message=f'Event {event_id} is not accessible by this run')
|
||||
|
||||
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'EVENT_GET error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Event get error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.EVENT_PAGE)
|
||||
async def event_page(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Page through event records.
|
||||
|
||||
Requires run_id authorization. Only allows access to current run's conversation.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
conversation_id = data.get('conversation_id')
|
||||
event_types = data.get('event_types')
|
||||
before_cursor = data.get('before_cursor')
|
||||
limit = data.get('limit', 50)
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'Event page',
|
||||
api_capability='event_page',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
conversation_id, scope_error = _resolve_run_conversation(
|
||||
session,
|
||||
conversation_id,
|
||||
'Event page',
|
||||
)
|
||||
if scope_error:
|
||||
return scope_error
|
||||
|
||||
if not conversation_id:
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': [],
|
||||
'next_cursor': None,
|
||||
'prev_cursor': None,
|
||||
'has_more': False,
|
||||
}
|
||||
)
|
||||
|
||||
# Parse cursor
|
||||
before_seq = int(before_cursor) if before_cursor else None
|
||||
|
||||
# Query events
|
||||
from ..agent.runner.event_log_store import EventLogStore
|
||||
|
||||
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id=conversation_id,
|
||||
event_types=event_types,
|
||||
before_seq=before_seq,
|
||||
limit=limit,
|
||||
**_run_scope_filters(session),
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'items': [_project_event_record_for_api(item) for item in items],
|
||||
'next_cursor': str(next_seq) if next_seq else None,
|
||||
'prev_cursor': None,
|
||||
'has_more': has_more,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Event page error: {e}')
|
||||
@@ -0,0 +1,488 @@
|
||||
"""Agent-runner protocol support: shared constants and authorization/scope/projection helpers extracted from handler.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Union
|
||||
import json
|
||||
import time
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from langbot_plugin.runtime.io import handler
|
||||
from langbot_plugin.entities.io.actions.enums import (
|
||||
PluginToRuntimeAction,
|
||||
)
|
||||
|
||||
|
||||
from ..core import app
|
||||
from ..agent.runner.session_registry import get_session_registry
|
||||
from ..agent.runner.result_normalizer import MAX_RESULT_SIZE_BYTES, STRICT_RESULT_PAYLOADS
|
||||
|
||||
|
||||
class _RuntimeActionName:
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
|
||||
AGENT_RUN_ADMIN_PERMISSION = 'agent_run:admin'
|
||||
RUNTIME_ADMIN_PERMISSION = 'runtime:admin'
|
||||
AGENT_RUNNER_ADMIN_PERMISSION = 'agent_runner:admin'
|
||||
LEDGER_ONLY_SIDE_EFFECTING_RESULT_TYPES = {
|
||||
'message.delta',
|
||||
'message.completed',
|
||||
'state.updated',
|
||||
'run.completed',
|
||||
'run.failed',
|
||||
}
|
||||
|
||||
|
||||
def _plugin_runtime_action(name: str, value: str) -> Any:
|
||||
return getattr(PluginToRuntimeAction, name, _RuntimeActionName(value))
|
||||
|
||||
|
||||
def _normalize_permission_set(value: Any) -> set[str]:
|
||||
if isinstance(value, str):
|
||||
return {permission.strip() for permission in value.split(',') if permission.strip()}
|
||||
if isinstance(value, list):
|
||||
return {str(item).strip() for item in value if str(item).strip()}
|
||||
if isinstance(value, dict):
|
||||
return {str(item).strip() for item, enabled in value.items() if enabled and str(item).strip()}
|
||||
return set()
|
||||
|
||||
|
||||
def _iter_agent_runner_admin_plugin_configs(ap: app.Application) -> list[dict[str, Any]]:
|
||||
instance_config = getattr(ap, 'instance_config', None)
|
||||
config_data = getattr(instance_config, 'data', {}) if instance_config is not None else {}
|
||||
if not isinstance(config_data, dict):
|
||||
return []
|
||||
agent_runner_config = config_data.get('agent_runner', {})
|
||||
if not isinstance(agent_runner_config, dict):
|
||||
return []
|
||||
raw_admin_plugins = agent_runner_config.get('admin_plugins', [])
|
||||
if isinstance(raw_admin_plugins, dict):
|
||||
items: list[dict[str, Any]] = []
|
||||
for identity, entry in raw_admin_plugins.items():
|
||||
if isinstance(entry, dict):
|
||||
merged = dict(entry)
|
||||
merged.setdefault('identity', identity)
|
||||
items.append(merged)
|
||||
else:
|
||||
items.append({'identity': identity, 'permissions': entry})
|
||||
return items
|
||||
if isinstance(raw_admin_plugins, list):
|
||||
return [item for item in raw_admin_plugins if isinstance(item, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _agent_runner_admin_permissions(ap: app.Application, plugin_identity: str | None) -> set[str]:
|
||||
if not isinstance(plugin_identity, str) or not plugin_identity.strip():
|
||||
return set()
|
||||
normalized_identity = plugin_identity.strip()
|
||||
permissions: set[str] = set()
|
||||
for entry in _iter_agent_runner_admin_plugin_configs(ap):
|
||||
if entry.get('enabled', True) is False:
|
||||
continue
|
||||
identity = entry.get('identity') or entry.get('plugin_identity') or entry.get('plugin') or entry.get('id')
|
||||
if identity != normalized_identity:
|
||||
continue
|
||||
permissions.update(_normalize_permission_set(entry.get('permissions')))
|
||||
permissions.update(_normalize_permission_set(entry.get('scopes')))
|
||||
return permissions
|
||||
|
||||
|
||||
def _has_agent_runner_admin_permission(
|
||||
ap: app.Application,
|
||||
plugin_identity: str | None,
|
||||
permission: str,
|
||||
) -> bool:
|
||||
permissions = _agent_runner_admin_permissions(ap, plugin_identity)
|
||||
if not permissions:
|
||||
return False
|
||||
domain = permission.split(':', 1)[0]
|
||||
return bool(
|
||||
permission in permissions
|
||||
or f'{domain}:*' in permissions
|
||||
or AGENT_RUNNER_ADMIN_PERMISSION in permissions
|
||||
or '*' in permissions
|
||||
)
|
||||
|
||||
|
||||
def _deadline_seconds_from_payload(data: dict[str, Any], default: int = 60) -> int:
|
||||
deadline_at = data.get('heartbeat_deadline_at')
|
||||
if deadline_at is not None:
|
||||
try:
|
||||
return max(int(float(deadline_at) - time.time()), 1)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
try:
|
||||
return max(int(data.get('heartbeat_ttl_seconds') or default), 1)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _get_run_authorization(session: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return the run-scoped authorization snapshot."""
|
||||
return session['authorization']
|
||||
|
||||
|
||||
def _run_matches_run_scope(session: dict[str, Any], run: dict[str, Any]) -> bool:
|
||||
authorization = _get_run_authorization(session)
|
||||
session_run_id = session.get('run_id')
|
||||
if run.get('run_id') == session_run_id:
|
||||
return True
|
||||
session_runner_id = session.get('runner_id') or authorization.get('runner_id')
|
||||
if not session_runner_id or run.get('runner_id') != session_runner_id:
|
||||
return False
|
||||
if not authorization.get('conversation_id'):
|
||||
return False
|
||||
if run.get('conversation_id') != authorization.get('conversation_id'):
|
||||
return False
|
||||
if authorization.get('bot_id') is not None and authorization.get('bot_id') != run.get('bot_id'):
|
||||
return False
|
||||
if authorization.get('workspace_id') is not None and authorization.get('workspace_id') != run.get('workspace_id'):
|
||||
return False
|
||||
if authorization.get('thread_id') != run.get('thread_id'):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _authorize_target_run(
|
||||
session: dict[str, Any],
|
||||
run: dict[str, Any],
|
||||
) -> handler.ActionResponse | None:
|
||||
"""Authorize non-admin target-run access against scope and runner owner."""
|
||||
if _run_matches_run_scope(session, run):
|
||||
return None
|
||||
return handler.ActionResponse.error(message=f'Run {run.get("run_id")} is not accessible by this run')
|
||||
|
||||
|
||||
def _validate_ledger_only_result_payload(
|
||||
*,
|
||||
ap: app.Application,
|
||||
runner_id: str | None,
|
||||
event_type: str,
|
||||
data: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Validate result payloads that can be safely stored without side effects."""
|
||||
try:
|
||||
result_json = json.dumps({'type': event_type, 'data': data})
|
||||
except (TypeError, ValueError) as exc:
|
||||
return f'event data must be JSON serializable: {exc}'
|
||||
if len(result_json) > MAX_RESULT_SIZE_BYTES:
|
||||
return f'event payload exceeds {MAX_RESULT_SIZE_BYTES} bytes'
|
||||
|
||||
payload_model = STRICT_RESULT_PAYLOADS.get(event_type)
|
||||
if payload_model is None:
|
||||
return f'unknown result type: {event_type}'
|
||||
try:
|
||||
payload_model.model_validate(data)
|
||||
except Exception as exc:
|
||||
return f'invalid {event_type} payload: {exc}'
|
||||
|
||||
if event_type in LEDGER_ONLY_SIDE_EFFECTING_RESULT_TYPES:
|
||||
if runner_id:
|
||||
ap.logger.warning(
|
||||
f'Runner {runner_id} attempted ledger-only append for side-effecting result type {event_type}'
|
||||
)
|
||||
return f'{event_type} must be emitted through the canonical runner result path'
|
||||
return None
|
||||
|
||||
|
||||
async def _require_runtime_write_ownership(
|
||||
*,
|
||||
store: Any,
|
||||
session: dict[str, Any],
|
||||
run: dict[str, Any],
|
||||
data: dict[str, Any],
|
||||
api_name: str,
|
||||
) -> handler.ActionResponse | None:
|
||||
"""Require current-run ownership or an active runtime claim for run writes."""
|
||||
if run.get('run_id') == session.get('run_id') and run.get('status') != 'claimed':
|
||||
return None
|
||||
|
||||
runtime_id = data.get('runtime_id')
|
||||
claim_token = data.get('claim_token')
|
||||
if not runtime_id or not claim_token:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'{api_name} requires active claim ownership for target run {run.get("run_id")}'
|
||||
)
|
||||
|
||||
if not await store.validate_active_claim(
|
||||
run_id=str(run.get('run_id')),
|
||||
runtime_id=str(runtime_id),
|
||||
claim_token=str(claim_token),
|
||||
):
|
||||
return handler.ActionResponse.error(
|
||||
message=f'{api_name} claim ownership is not active for target run {run.get("run_id")}'
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_state_scope(
|
||||
session: dict[str, Any],
|
||||
scope: str,
|
||||
) -> tuple[dict[str, Any] | None, str | None, handler.ActionResponse | None]:
|
||||
"""Resolve state policy/context for an authorized run scope."""
|
||||
authorization = _get_run_authorization(session)
|
||||
state_policy = authorization['state_policy']
|
||||
|
||||
if not state_policy.get('enable_state', True):
|
||||
return None, None, handler.ActionResponse.error(message='State access is disabled by binding policy')
|
||||
|
||||
state_scopes = state_policy.get('state_scopes', ['conversation', 'actor'])
|
||||
if scope not in state_scopes:
|
||||
return None, None, handler.ActionResponse.error(message=f'Scope "{scope}" is not enabled by binding policy')
|
||||
|
||||
state_context = authorization['state_context']
|
||||
scope_key = state_context.get('scope_keys', {}).get(scope)
|
||||
if not scope_key:
|
||||
return None, None, handler.ActionResponse.error(message=f'Scope key not available for scope "{scope}"')
|
||||
|
||||
return state_context, scope_key, None
|
||||
|
||||
|
||||
async def _validate_agent_run_session(
|
||||
run_id: str,
|
||||
caller_plugin_identity: str | None,
|
||||
ap: app.Application,
|
||||
api_name: str,
|
||||
api_capability: str | None = None,
|
||||
allow_persistent_authorization: bool = False,
|
||||
admin_permission: str | None = None,
|
||||
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
|
||||
"""Validate an AgentRunner pull API run session and run-scoped API access."""
|
||||
if (
|
||||
not run_id
|
||||
and admin_permission
|
||||
and _has_agent_runner_admin_permission(
|
||||
ap,
|
||||
caller_plugin_identity,
|
||||
admin_permission,
|
||||
)
|
||||
):
|
||||
return {
|
||||
'run_id': run_id,
|
||||
'runner_id': None,
|
||||
'query_id': None,
|
||||
'plugin_identity': caller_plugin_identity,
|
||||
'authorization': {},
|
||||
'status': {},
|
||||
'steering_queue': [],
|
||||
}, None
|
||||
|
||||
session_registry = get_session_registry()
|
||||
session = await session_registry.get(run_id)
|
||||
if not session:
|
||||
if allow_persistent_authorization:
|
||||
session = await _load_persistent_agent_run_session(run_id, ap, api_name)
|
||||
if not session:
|
||||
return None, handler.ActionResponse.error(message=f'Run session {run_id} not found or expired')
|
||||
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
|
||||
ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity')
|
||||
return None, handler.ActionResponse.error(message=f'Run session {run_id} has no plugin_identity')
|
||||
if not caller_plugin_identity:
|
||||
return None, handler.ActionResponse.error(message=f'caller_plugin_identity is required for run_id {run_id}')
|
||||
if caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(message=f'Plugin identity mismatch for run_id {run_id}')
|
||||
|
||||
if api_capability:
|
||||
available_apis = _get_run_authorization(session).get('available_apis', {})
|
||||
has_admin_permission = bool(admin_permission) and _has_agent_runner_admin_permission(
|
||||
ap,
|
||||
caller_plugin_identity,
|
||||
admin_permission,
|
||||
)
|
||||
if not available_apis.get(api_capability, False) and not has_admin_permission:
|
||||
return None, handler.ActionResponse.error(message=f'{api_name} access not authorized')
|
||||
|
||||
return session, None
|
||||
|
||||
|
||||
async def _load_persistent_agent_run_session(
|
||||
run_id: str,
|
||||
ap: app.Application,
|
||||
api_name: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Load an expired run session from the AgentRun authorization snapshot."""
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ..entity.persistence.agent_run import AgentRun
|
||||
|
||||
engine = ap.persistence_mgr.get_db_engine()
|
||||
session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as db_session:
|
||||
result = await db_session.execute(sqlalchemy.select(AgentRun).where(AgentRun.run_id == run_id))
|
||||
run = result.scalars().first()
|
||||
except Exception as e:
|
||||
ap.logger.error(f'{api_name}: failed to load persistent authorization for run_id {run_id}: {e}', exc_info=True)
|
||||
return None
|
||||
|
||||
if run is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
authorization = json.loads(run.authorization_json) if run.authorization_json else {}
|
||||
except (TypeError, ValueError) as e:
|
||||
ap.logger.warning(f'{api_name}: run_id {run_id} has invalid authorization_json: {e}')
|
||||
return None
|
||||
|
||||
if not isinstance(authorization, dict):
|
||||
ap.logger.warning(f'{api_name}: run_id {run_id} authorization_json is not an object')
|
||||
return None
|
||||
|
||||
return {
|
||||
'run_id': run.run_id,
|
||||
'runner_id': authorization.get('runner_id') or run.runner_id,
|
||||
'query_id': None,
|
||||
'plugin_identity': authorization.get('plugin_identity'),
|
||||
'authorization': authorization,
|
||||
'status': {},
|
||||
'steering_queue': [],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_run_conversation(
|
||||
session: dict[str, Any],
|
||||
requested_conversation_id: str | None,
|
||||
api_name: str,
|
||||
) -> tuple[str | None, handler.ActionResponse | None]:
|
||||
"""Resolve and enforce current-run conversation scope."""
|
||||
session_conversation_id = _get_run_authorization(session).get('conversation_id')
|
||||
|
||||
if requested_conversation_id:
|
||||
if not session_conversation_id:
|
||||
return None, handler.ActionResponse.error(message=f'{api_name} is not available without a run conversation')
|
||||
if requested_conversation_id != session_conversation_id:
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Conversation {requested_conversation_id} is not accessible by this run'
|
||||
)
|
||||
return requested_conversation_id, None
|
||||
|
||||
return session_conversation_id, None
|
||||
|
||||
|
||||
def _run_scope_filters(session: dict[str, Any]) -> dict[str, Any]:
|
||||
authorization = _get_run_authorization(session)
|
||||
return {
|
||||
'bot_id': authorization.get('bot_id'),
|
||||
'workspace_id': authorization.get('workspace_id'),
|
||||
'thread_id': authorization.get('thread_id'),
|
||||
'strict_thread': True,
|
||||
}
|
||||
|
||||
|
||||
def _run_ledger_scope_filters(session: dict[str, Any]) -> dict[str, Any]:
|
||||
authorization = _get_run_authorization(session)
|
||||
filters = _run_scope_filters(session)
|
||||
filters['runner_id'] = session.get('runner_id') or authorization.get('runner_id')
|
||||
return filters
|
||||
|
||||
|
||||
def _event_matches_run_scope(session: dict[str, Any], event: dict[str, Any]) -> bool:
|
||||
authorization = _get_run_authorization(session)
|
||||
if authorization.get('conversation_id') != event.get('conversation_id'):
|
||||
return False
|
||||
if authorization.get('bot_id') is not None and authorization.get('bot_id') != event.get('bot_id'):
|
||||
return False
|
||||
if authorization.get('workspace_id') is not None and authorization.get('workspace_id') != event.get('workspace_id'):
|
||||
return False
|
||||
if authorization.get('thread_id') != event.get('thread_id'):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _project_event_record_for_api(event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Project EventLogStore rows onto the SDK AgentEventRecord DTO."""
|
||||
seq = event.get('seq') or event.get('id')
|
||||
return {
|
||||
'event_id': event.get('event_id'),
|
||||
'event_type': event.get('event_type'),
|
||||
'event_time': event.get('event_time'),
|
||||
'source': event.get('source'),
|
||||
'bot_id': event.get('bot_id'),
|
||||
'workspace_id': event.get('workspace_id'),
|
||||
'conversation_id': event.get('conversation_id'),
|
||||
'thread_id': event.get('thread_id'),
|
||||
'actor_type': event.get('actor_type'),
|
||||
'actor_id': event.get('actor_id'),
|
||||
'actor_name': event.get('actor_name'),
|
||||
'subject_type': event.get('subject_type'),
|
||||
'subject_id': event.get('subject_id'),
|
||||
'input_summary': event.get('input_summary'),
|
||||
'input_ref': event.get('input_ref'),
|
||||
'raw_ref': event.get('raw_ref'),
|
||||
'seq': seq,
|
||||
'cursor': event.get('cursor') or (str(seq) if seq is not None else None),
|
||||
'created_at': event.get('created_at'),
|
||||
'metadata': event.get('metadata') or {},
|
||||
}
|
||||
|
||||
|
||||
def _project_runner_descriptor_for_api(descriptor: Any) -> dict[str, Any]:
|
||||
"""Project an AgentRunnerDescriptor-like object onto a JSON dict."""
|
||||
if isinstance(descriptor, dict):
|
||||
return dict(descriptor)
|
||||
if hasattr(descriptor, 'model_dump'):
|
||||
return descriptor.model_dump(mode='json')
|
||||
return {
|
||||
'id': getattr(descriptor, 'id', None),
|
||||
'source': getattr(descriptor, 'source', None),
|
||||
'label': getattr(descriptor, 'label', {}),
|
||||
'description': getattr(descriptor, 'description', None),
|
||||
'plugin_author': getattr(descriptor, 'plugin_author', None),
|
||||
'plugin_name': getattr(descriptor, 'plugin_name', None),
|
||||
'runner_name': getattr(descriptor, 'runner_name', None),
|
||||
'plugin_version': getattr(descriptor, 'plugin_version', None),
|
||||
'config_schema': getattr(descriptor, 'config_schema', []),
|
||||
'capabilities': getattr(descriptor, 'capabilities', {}),
|
||||
'permissions': getattr(descriptor, 'permissions', {}),
|
||||
'raw_manifest': getattr(descriptor, 'raw_manifest', {}),
|
||||
}
|
||||
|
||||
|
||||
async def _record_agent_runner_admin_action(
|
||||
ap: app.Application,
|
||||
store: Any,
|
||||
*,
|
||||
action: str,
|
||||
caller_plugin_identity: str | None,
|
||||
permission: str,
|
||||
durable_run_id: str | None = None,
|
||||
target_runtime_id: str | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Record a small audit trail for privileged AgentRunner operations."""
|
||||
audit_data: dict[str, Any] = {
|
||||
'action': action,
|
||||
'caller_plugin_identity': caller_plugin_identity,
|
||||
'permission': permission,
|
||||
}
|
||||
if durable_run_id:
|
||||
audit_data['target_run_id'] = durable_run_id
|
||||
if target_runtime_id:
|
||||
audit_data['target_runtime_id'] = target_runtime_id
|
||||
if detail:
|
||||
audit_data['detail'] = detail
|
||||
|
||||
ap.logger.info('Agent runner admin action: %s', audit_data)
|
||||
if not durable_run_id or store is None or not hasattr(store, 'append_audit_event'):
|
||||
return
|
||||
|
||||
try:
|
||||
await store.append_audit_event(
|
||||
run_id=str(durable_run_id),
|
||||
event_type=f'admin.{action}',
|
||||
data=audit_data,
|
||||
metadata={'permission': permission},
|
||||
)
|
||||
except Exception as exc:
|
||||
ap.logger.warning(f'Failed to record AgentRunner admin audit event: {exc}', exc_info=True)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,316 @@
|
||||
"""Agent-runner steering / state actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
from langbot_plugin.runtime.io import handler
|
||||
from langbot_plugin.entities.io.actions.enums import (
|
||||
PluginToRuntimeAction,
|
||||
)
|
||||
|
||||
|
||||
from ..agent.runner.session_registry import get_session_registry
|
||||
|
||||
from .agent_run_support import (
|
||||
_resolve_state_scope,
|
||||
_validate_agent_run_session,
|
||||
)
|
||||
|
||||
|
||||
def register(h):
|
||||
@h.action(PluginToRuntimeAction.STEERING_PULL)
|
||||
async def steering_pull(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Pull pending steering/follow-up inputs for the current run."""
|
||||
run_id = data.get('run_id')
|
||||
mode = data.get('mode', 'all')
|
||||
limit = data.get('limit')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if limit is not None:
|
||||
try:
|
||||
limit = int(limit)
|
||||
except (TypeError, ValueError):
|
||||
return handler.ActionResponse.error(message='limit must be an integer')
|
||||
if limit <= 0:
|
||||
return handler.ActionResponse.error(message='limit must be > 0')
|
||||
limit = min(limit, 100)
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'Steering pull',
|
||||
api_capability='steering_pull',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
session_registry = get_session_registry()
|
||||
items = await session_registry.pull_steering(
|
||||
run_id,
|
||||
mode=str(mode or 'all'),
|
||||
limit=limit,
|
||||
)
|
||||
if items:
|
||||
try:
|
||||
from ..agent.runner.event_log_store import EventLogStore
|
||||
|
||||
store = EventLogStore(h.ap.persistence_mgr.get_db_engine())
|
||||
for item in items:
|
||||
event = item.get('event') if isinstance(item, dict) else None
|
||||
conversation = item.get('conversation') if isinstance(item, dict) else None
|
||||
actor = item.get('actor') if isinstance(item, dict) else None
|
||||
subject = item.get('subject') if isinstance(item, dict) else None
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
await store.append_event(
|
||||
event_id=None,
|
||||
event_type='steering.injected',
|
||||
source='agent_runner',
|
||||
bot_id=conversation.get('bot_id') if isinstance(conversation, dict) else None,
|
||||
workspace_id=conversation.get('workspace_id') if isinstance(conversation, dict) else None,
|
||||
conversation_id=conversation.get('conversation_id') if isinstance(conversation, dict) else None,
|
||||
thread_id=conversation.get('thread_id') if isinstance(conversation, dict) else None,
|
||||
actor_type=actor.get('actor_type') if isinstance(actor, dict) else None,
|
||||
actor_id=actor.get('actor_id') if isinstance(actor, dict) else None,
|
||||
actor_name=actor.get('actor_name') if isinstance(actor, dict) else None,
|
||||
subject_type=subject.get('subject_type') if isinstance(subject, dict) else None,
|
||||
subject_id=subject.get('subject_id') if isinstance(subject, dict) else None,
|
||||
input_summary=f'steering injected from {event.get("event_id")}',
|
||||
run_id=run_id,
|
||||
runner_id=session.get('runner_id') if isinstance(session, dict) else None,
|
||||
metadata={
|
||||
'steering': {
|
||||
'status': 'injected',
|
||||
'source_event_id': event.get('event_id'),
|
||||
'claimed_by_run_id': item.get('claimed_run_id') if isinstance(item, dict) else run_id,
|
||||
'claimed_runner_id': item.get('runner_id') if isinstance(item, dict) else None,
|
||||
'claimed_at': item.get('claimed_at') if isinstance(item, dict) else None,
|
||||
'pull_mode': str(mode or 'all'),
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
h.ap.logger.warning(
|
||||
f'Failed to write steering injection audit for run {run_id}: {exc}',
|
||||
exc_info=True,
|
||||
)
|
||||
return handler.ActionResponse.success(data={'items': items})
|
||||
|
||||
# ================= State APIs (run-scoped, policy-enforced) =================
|
||||
|
||||
@h.action(PluginToRuntimeAction.STATE_GET)
|
||||
async def state_get(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get a state value from host-owned state store.
|
||||
|
||||
Requires run_id authorization and scope enabled by state_policy.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
scope = data.get('scope')
|
||||
key = data.get('key')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not scope:
|
||||
return handler.ActionResponse.error(message='scope is required')
|
||||
|
||||
if not key:
|
||||
return handler.ActionResponse.error(message='key is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'State get',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
|
||||
if state_error:
|
||||
return state_error
|
||||
|
||||
# Get state from persistent store
|
||||
from ..agent.runner.persistent_state_store import get_persistent_state_store
|
||||
|
||||
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
value = await store.state_get(scope_key, key)
|
||||
return handler.ActionResponse.success(data={'value': value})
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'STATE_GET error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'State get error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.STATE_SET)
|
||||
async def state_set(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Set a state value in host-owned state store.
|
||||
|
||||
Requires run_id authorization and scope enabled by state_policy.
|
||||
Value must be JSON-serializable and size-limited.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
scope = data.get('scope')
|
||||
key = data.get('key')
|
||||
value = data.get('value')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not scope:
|
||||
return handler.ActionResponse.error(message='scope is required')
|
||||
|
||||
if not key:
|
||||
return handler.ActionResponse.error(message='key is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'State set',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
state_context, scope_key, state_error = _resolve_state_scope(session, scope)
|
||||
if state_error:
|
||||
return state_error
|
||||
|
||||
# Get additional context for DB insert
|
||||
runner_id = session.get('runner_id', '')
|
||||
binding_identity = state_context.get('binding_identity', 'unknown')
|
||||
|
||||
# Set state in persistent store
|
||||
from ..agent.runner.persistent_state_store import get_persistent_state_store
|
||||
|
||||
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
success, error = await store.state_set(
|
||||
scope_key=scope_key,
|
||||
state_key=key,
|
||||
value=value,
|
||||
runner_id=runner_id,
|
||||
binding_identity=binding_identity,
|
||||
scope=scope,
|
||||
context=state_context,
|
||||
logger=h.ap.logger,
|
||||
)
|
||||
|
||||
if not success:
|
||||
return handler.ActionResponse.error(message=error or 'Failed to set state')
|
||||
|
||||
return handler.ActionResponse.success(data={'success': True})
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'STATE_SET error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'State set error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.STATE_DELETE)
|
||||
async def state_delete(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Delete a state value from host-owned state store.
|
||||
|
||||
Requires run_id authorization and scope enabled by state_policy.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
scope = data.get('scope')
|
||||
key = data.get('key')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not scope:
|
||||
return handler.ActionResponse.error(message='scope is required')
|
||||
|
||||
if not key:
|
||||
return handler.ActionResponse.error(message='key is required')
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'State delete',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
|
||||
if state_error:
|
||||
return state_error
|
||||
|
||||
# Delete state from persistent store
|
||||
from ..agent.runner.persistent_state_store import get_persistent_state_store
|
||||
|
||||
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
deleted = await store.state_delete(scope_key, key)
|
||||
return handler.ActionResponse.success(data={'success': deleted})
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'STATE_DELETE error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'State delete error: {e}')
|
||||
|
||||
@h.action(PluginToRuntimeAction.STATE_LIST)
|
||||
async def state_list(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""List state keys in a scope.
|
||||
|
||||
Requires run_id authorization and scope enabled by state_policy.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
scope = data.get('scope')
|
||||
prefix = data.get('prefix')
|
||||
limit = data.get('limit', 100)
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not scope:
|
||||
return handler.ActionResponse.error(message='scope is required')
|
||||
|
||||
# Validate limit
|
||||
if not isinstance(limit, int) or limit <= 0:
|
||||
limit = 100
|
||||
limit = min(limit, 100) # Cap at 100
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
run_id,
|
||||
caller_plugin_identity,
|
||||
h.ap,
|
||||
'State list',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
_state_context, scope_key, state_error = _resolve_state_scope(session, scope)
|
||||
if state_error:
|
||||
return state_error
|
||||
|
||||
# List state keys from persistent store
|
||||
from ..agent.runner.persistent_state_store import get_persistent_state_store
|
||||
|
||||
store = get_persistent_state_store(h.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
keys, has_more = await store.state_list(scope_key, prefix, limit)
|
||||
return handler.ActionResponse.success(
|
||||
data={
|
||||
'keys': keys,
|
||||
'has_more': has_more,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
h.ap.logger.error(f'STATE_LIST error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'State list error: {e}')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -90,7 +90,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Get the STATE_GET action handler (actions dict is keyed by action value string)
|
||||
@@ -111,7 +111,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -146,7 +146,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -182,7 +182,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -219,7 +219,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -255,7 +255,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -292,7 +292,7 @@ class TestStateAPIHandlerAuthorization:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -340,7 +340,7 @@ class TestStateAPIFullFlowWithRealDB:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
# Verify session has correct state_context
|
||||
@@ -446,7 +446,7 @@ class TestStateHandlerReadsFromAuthorizationSnapshot:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_get_handler = handler.actions[PluginToRuntimeAction.STATE_GET.value]
|
||||
|
||||
@@ -490,7 +490,7 @@ class TestStateHandlerReadsFromAuthorizationSnapshot:
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
with patch('langbot.pkg.plugin.handler.get_session_registry', return_value=session_registry):
|
||||
with patch('langbot.pkg.plugin.agent_run_support.get_session_registry', return_value=session_registry):
|
||||
handler = RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
state_set_handler = handler.actions[PluginToRuntimeAction.STATE_SET.value]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user